|
import argparse |
|
|
|
|
|
def get_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for LAVT training and testing.

    Every option is optional and carries a default, so ``parse_args([])``
    succeeds. Options spelled with hyphens (``--batch-size``,
    ``--output-dir``, ``--print-freq``) are exposed on the parsed namespace
    with underscores; ``--wd`` / ``--weight-decay`` is stored as
    ``weight_decay``.

    Returns:
        argparse.ArgumentParser: the configured parser. Parsing is left to
        the caller.
    """
    parser = argparse.ArgumentParser(description='LAVT training and testing')
    add = parser.add_argument

    add('--amsgrad', action='store_true',
        help='if true, set amsgrad to True in an Adam or AdamW optimizer.')
    add('-b', '--batch-size', default=8, type=int)
    add('--bert_tokenizer', default='bert-base-uncased', help='BERT tokenizer')
    add('--ck_bert', default='bert-base-uncased', help='pre-trained BERT weights')
    add('--dataset', default='refcoco', help='refcoco, refcoco+, or refcocog')
    # Adjacent string literals are concatenated verbatim, so each fragment of
    # a multi-line help text must carry its own separating space.
    add('--ddp_trained_weights', action='store_true',
        help='Only needs specified when testing, '
             'whether the weights to be loaded are from a DDP-trained model')
    add('--device', default='cuda:0', help='device')
    add('--epochs', default=40, type=int, metavar='N', help='number of total epochs to run')
    add('--fusion_drop', default=0.0, type=float, help='dropout rate for PWAMs')
    add('--img_size', default=480, type=int, help='input image size')
    add("--local_rank", type=int, help='local rank for DistributedDataParallel')
    add('--lr', default=0.00005, type=float, help='the initial learning rate')
    add('--mha', default='',
        help='If specified, should be in the format of a-b-c-d, e.g., 4-4-4-4, '
             'where a, b, c, and d refer to the numbers of heads in stage-1, '
             'stage-2, stage-3, and stage-4 PWAMs')
    add('--model', default='lavt', help='model: lavt, lavt_one')
    add('--model_id', default='lavt', help='name to identify the model')
    add('--output-dir', default='./checkpoints/', help='path where to save checkpoint weights')
    add('--pin_mem', action='store_true',
        help='If true, pin memory when using the data loader.')
    add('--pretrained_swin_weights', default='',
        help='path to pre-trained Swin backbone weights')
    add('--print-freq', default=10, type=int, help='print frequency')
    add('--refer_data_root', default='./refer/data/', help='REFER dataset root directory')
    add('--resume', default='auto', help='resume from checkpoint')
    add('--split', default='test', help='only used when testing')
    add('--splitBy', default='unc',
        help='change to umd or google when the dataset is G-Ref (RefCOCOg)')
    add('--swin_type', default='base',
        help='tiny, small, base, or large variants of the Swin Transformer')
    add('--wd', '--weight-decay', default=1e-2, type=float, metavar='W', help='weight decay',
        dest='weight_decay')
    add('--window12', action='store_true',
        help='only needs specified when testing, '
             'when training, window size is inferred from pre-trained weights file name '
             '(containing \'window12\'). Initialize Swin with window size 12 instead of the default 7.')
    add('-j', '--workers', default=16, type=int, metavar='N', help='number of data loading workers')
    add('--seed', default=0, type=int)
    add('--max_ckpt', default=2, type=int)

    # The options below have no help text in the original; their meaning is
    # defined by the code that consumes them, which is not visible here.
    add('--num_object_queries', default=1, type=int)
    add('--no_object_weight', default=0.0, type=float)
    add('--class_weight', default=2.0, type=float)
    add('--dice_weight', default=2.0, type=float)
    add('--mask_weight', default=2.0, type=float)
    add('--train_num_points', default=12544, type=int)
    add('--dim_feedforward', default=2048, type=int)
    add('--dec_layers', default=10, type=int)
    add('--transformer_enc_layers', default=4, type=int)

    add('--plic_pos_weight', default=0.5, type=float)
    add('--plic_neg_weight', default=0.5, type=float)
    add('--plic_lang_weight', default=0.5, type=float)
    add('--plic_pos_alpha', default=0.0, type=float)
    add('--plic_neg_alpha', default=0.0, type=float)
    add('--plic_lang_alpha', default=0.0, type=float)
    add('--plic_pos_temp', default=0.2, type=float)
    add('--plic_neg_temp', default=0.2, type=float)
    add('--plic_lang_temp', default=0.2, type=float)
    add('--smlm_weight', default=1.0, type=float)
    add('--vis_dir', default='./vis_dir')

    return parser
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: build the parser and parse the process's
    # command-line arguments. Nothing further is done with the result here.
    parser = get_parser()
    # NOTE: despite its name, args_dict is an argparse.Namespace (attribute
    # access), not a dict; use vars(args_dict) if a mapping is needed.
    args_dict = parser.parse_args()
|
|