|
import argparse |
|
|
|
# Top-level parser shared by every argument group defined in this module.
parser = argparse.ArgumentParser(description="LANet")

# Collects each group returned by add_argument_group(), in creation order.
arg_lists = []
|
|
|
|
|
def str2bool(v):
    """Convert an option value to a bool; intended as an argparse ``type``.

    Case-insensitive, with surrounding whitespace ignored:
    ``true``/``1``/``yes``/``y``/``t``/``on`` map to ``True``.
    ``false``/``0``/``no``/``n``/``f``/``off`` map to ``False``.
    Values that are already ``bool`` are returned unchanged.

    Args:
        v: The raw option value (normally a command-line string).

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If ``v`` is not a recognised spelling.
            Raising (instead of silently returning False) makes argparse
            report a typo such as ``--use_gpu ture`` as a usage error.
    """
    if isinstance(v, bool):
        return v
    normalized = str(v).strip().lower()
    if normalized in ("true", "1", "yes", "y", "t", "on"):
        return True
    if normalized in ("false", "0", "no", "n", "f", "off"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {v!r}")
|
|
|
|
|
def add_argument_group(name):
    """Create a titled argument group on the module-level ``parser``.

    The new group is also recorded in the module-level ``arg_lists`` list.

    Args:
        name: Title of the group (shown in ``--help`` output).

    Returns:
        The newly created argparse argument group.
    """
    arg_lists.append(parser.add_argument_group(name))
    # The group just appended is the last element of the registry.
    return arg_lists[-1]
|
|
|
|
|
|
|
def _tuple_arg(cast, length):
    """Build an argparse ``type`` callable that parses a comma-separated tuple.

    ``type=tuple`` cannot be used here: it splits a command-line string into
    single characters (``tuple("240,320")`` gives ``("2", "4", "0", ",", ...)``).

    Args:
        cast: Callable applied to each comma-separated item, e.g. ``int``.
        length: Exact number of items required.

    Returns:
        A function that maps a string such as ``"240,320"`` (optionally
        wrapped in parentheses or brackets) to a tuple of ``length`` cast
        values. It raises ``argparse.ArgumentTypeError`` on malformed input.
    """

    def parse(text):
        items = [item.strip() for item in text.strip().strip("()[]").split(",")]
        try:
            values = tuple(cast(item) for item in items)
        except ValueError:
            values = None
        if values is None or len(values) != length:
            raise argparse.ArgumentTypeError(
                f"expected {length} comma-separated {cast.__name__} values, "
                f"got {text!r}"
            )
        return values

    return parse


# Options describing the training data and how it is loaded.
traindata_arg = add_argument_group("Traindata Params")
traindata_arg.add_argument("--train_txt", type=str, default="", help="Train set.")
traindata_arg.add_argument(
    "--train_root", type=str, default="", help="Where the train images are."
)
traindata_arg.add_argument(
    "--batch_size", type=int, default=8, help="# of images in each batch of data"
)
traindata_arg.add_argument(
    "--num_workers",
    type=int,
    default=4,
    help="# of subprocesses to use for data loading",
)
traindata_arg.add_argument(
    "--pin_memory",
    type=str2bool,
    default=True,
    # The previous help text was copy-pasted from --num_workers.
    help="Whether to pin memory for data loading.",
)
traindata_arg.add_argument(
    "--shuffle",
    type=str2bool,
    default=True,
    help="Whether to shuffle the train and valid indices",
)
# The tuple defaults are not strings, so argparse uses them as-is; only values
# given on the command line go through the comma-separated parser.
traindata_arg.add_argument(
    "--image_shape",
    type=_tuple_arg(int, 2),
    default=(240, 320),
    help="Two comma-separated ints, e.g. 240,320.",
)
traindata_arg.add_argument(
    "--jittering",
    type=_tuple_arg(float, 4),
    default=(0.5, 0.5, 0.2, 0.05),
    help="Four comma-separated floats, e.g. 0.5,0.5,0.2,0.05.",
)
|
|
|
|
|
# Checkpoint-naming option.
storage_arg = add_argument_group("Storage")
storage_arg.add_argument(
    "--ckpt_name",
    default="PointModel",
    type=str,
    help="",
)
|
|
|
|
|
# Optimisation / schedule options. Each row is (flag, type, default, help);
# rows are registered in order, so ``--help`` lists them in the same sequence.
train_arg = add_argument_group("Training Params")
for _flag, _kind, _default, _help in (
    ("--start_epoch", int, 0, ""),
    ("--max_epoch", int, 12, ""),
    ("--init_lr", float, 3e-4, "Initial learning rate value."),
    ("--lr_factor", float, 0.5, "Reduce learning rate value."),
    ("--momentum", float, 0.9, "Nesterov momentum value."),
    ("--display", int, 50, ""),
):
    train_arg.add_argument(_flag, type=_kind, default=_default, help=_help)
# Keep the loop variables out of the module namespace.
del _flag, _kind, _default, _help
|
|
|
|
|
# Loss-function weights and thresholds.
loss_arg = add_argument_group("Loss function Params")
loss_arg.add_argument("--score_weight", type=float, default=1.0, help="")
loss_arg.add_argument("--loc_weight", type=float, default=1.0, help="")
loss_arg.add_argument("--desc_weight", type=float, default=4.0, help="")
loss_arg.add_argument("--corres_weight", type=float, default=0.5, help="")
# type=float (previously int): CLI values now have the same type as the 4.0
# default, and fractional thresholds such as 2.5 are accepted instead of rejected.
loss_arg.add_argument("--corres_threshold", type=float, default=4.0, help="")
|
|
|
|
|
# Device, seeding and checkpoint-directory options, registered in row order
# from (flag, type, default, help) tuples.
misc_arg = add_argument_group("Misc.")
for _flag, _kind, _default, _help in (
    ("--use_gpu", str2bool, True, "Whether to run on the GPU."),
    ("--gpu", int, 0, "Which GPU to run on."),
    ("--seed", int, 1001, "Seed to ensure reproducibility."),
    (
        "--ckpt_dir",
        str,
        "./checkpoints",
        "Directory in which to save model checkpoints.",
    ),
):
    misc_arg.add_argument(_flag, type=_kind, default=_default, help=_help)
# Keep the loop variables out of the module namespace.
del _flag, _kind, _default, _help
|
|
|
|
|
def get_config(args=None):
    """Parse options with the module-level ``parser``, tolerating unknown ones.

    Args:
        args: Optional list of argument strings to parse. When ``None`` (the
            default, matching the previous zero-argument behaviour), argparse
            reads ``sys.argv[1:]``.

    Returns:
        A ``(config, unparsed)`` tuple. ``config`` is the
        ``argparse.Namespace`` of recognised options; ``unparsed`` is the list
        of leftover argument strings that the parser did not recognise.
    """
    # parse_known_args (not parse_args) so unrecognised flags are returned
    # to the caller instead of aborting with a usage error.
    config, unparsed = parser.parse_known_args(args)
    return config, unparsed
|
|