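# NOTE: the lines below appear to be file-viewer page residue (author, commit
# message, commit hash, link labels, file size), kept here as comments so the
# module remains valid Python.
# Vincentqyw
# update: features and matchers
# a80d6bb
# raw
# history blame
# 3.89 kB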
import argparse
def str2bool(v):
    """Convert a command-line style value to a ``bool``.

    Intended for use as an ``argparse`` ``type=`` callable.

    Args:
        v: Either an actual ``bool`` (returned unchanged) or a string.
            Strings are compared case-insensitively after stripping
            surrounding whitespace.

    Returns:
        ``True`` for ``"true"``, ``"1"``, ``"yes"``, ``"y"`` or ``"t"``;
        ``False`` for every other string.
    """
    # A real bool has no ``.lower()``; pass it through instead of crashing.
    if isinstance(v, bool):
        return v
    # Strip so values such as " true" or "true\n" are not silently False.
    return v.strip().lower() in ("true", "1", "yes", "y", "t")
# Single module-wide parser that every option group below is attached to.
parser = argparse.ArgumentParser()
# Every group created through add_argument_group() is recorded here.
arg_lists = []
def add_argument_group(name):
    """Create a named option group on the module parser and record it.

    Args:
        name: Title passed to ``parser.add_argument_group``.

    Returns:
        The newly created argument group, which is also appended to
        ``arg_lists``.
    """
    group = parser.add_argument_group(name)
    arg_lists.append(group)
    return group
# -----------------------------------------------------------------------------
# Network
# Options selecting the model and its configuration file.
net_arg = add_argument_group("Network")
net_arg.add_argument(
    "--model_name",
    type=str,
    default="SGM",
    help="model for training",
)
net_arg.add_argument(
    "--config_path",
    type=str,
    default="configs/sgm.yaml",
    help="config path for model",
)
# -----------------------------------------------------------------------------
# Data
# Options describing where the data lives and how keypoints are prepared.
data_arg = add_argument_group("Data")
data_arg.add_argument(
    "--rawdata_path",
    type=str,
    default="rawdata",
    help="path for rawdata",
)
data_arg.add_argument(
    "--dataset_path",
    type=str,
    default="dataset",
    help="path for dataset",
)
data_arg.add_argument(
    "--desc_path",
    type=str,
    default="desc",
    help="path for descriptor(kpt) dir",
)
data_arg.add_argument(
    "--num_kpt",
    type=int,
    default=1000,
    help="number of kpt for training",
)
data_arg.add_argument(
    "--input_normalize",
    type=str,
    default="img",
    help="normalize type for input kpt, img or intrinsic",
)
# The default is already a bool; str2bool only sees values given as strings
# on the command line.
data_arg.add_argument(
    "--data_aug",
    type=str2bool,
    default=True,
    help="apply kpt coordinate homography augmentation",
)
data_arg.add_argument(
    "--desc_suffix",
    type=str,
    default="suffix",
    help="desc file suffix",
)
# -----------------------------------------------------------------------------
# Loss
# Options weighting the loss terms and setting the inlier threshold.
loss_arg = add_argument_group("loss")
loss_arg.add_argument(
    "--momentum",
    type=float,
    default=0.9,
    help="momentum",
)
loss_arg.add_argument(
    "--seed_loss_weight",
    type=float,
    default=250,
    help="confidence loss weight for sgm",
)
loss_arg.add_argument(
    "--mid_loss_weight",
    type=float,
    default=1,
    help="midseeding loss weight for sgm",
)
loss_arg.add_argument(
    "--inlier_th",
    type=float,
    default=5e-3,
    help="inlier threshold for epipolar distance (for sgm and visualization)",
)
# -----------------------------------------------------------------------------
# Training
# Options controlling the training loop: optimisation, logging and schedule.
train_arg = add_argument_group("Train")
train_arg.add_argument(
    "--train_lr",
    type=float,
    default=1e-4,
    help="learning rate",
)
train_arg.add_argument(
    "--train_batch_size",
    type=int,
    default=16,
    help="batch size",
)
train_arg.add_argument(
    "--gpu_id",
    type=str,
    default="0",
    help="id(s) for CUDA_VISIBLE_DEVICES",
)
train_arg.add_argument(
    "--train_iter",
    type=int,
    default=1000000,
    help="training iterations to perform",
)
train_arg.add_argument(
    "--log_base",
    type=str,
    default="./log/",
    help="log path",
)
train_arg.add_argument(
    "--val_intv",
    type=int,
    default=20000,
    help="validation interval",
)
# NOTE(review): the help text says "summary interval" although the flag is
# named save_intv -- confirm which is intended before changing the text.
train_arg.add_argument(
    "--save_intv",
    type=int,
    default=1000,
    help="summary interval",
)
train_arg.add_argument(
    "--log_intv",
    type=int,
    default=100,
    help="log interval",
)
train_arg.add_argument(
    "--decay_rate",
    type=float,
    default=0.999996,
    help="lr decay rate",
)
# NOTE(review): parsed as float while the default is an int literal; verify
# against the consumer whether an integer iteration count is expected.
train_arg.add_argument(
    "--decay_iter",
    type=float,
    default=300000,
    help="lr decay iter",
)
train_arg.add_argument(
    "--local_rank",
    type=int,
    default=0,
    help="local rank for ddp",
)
train_arg.add_argument(
    "--train_vis_folder",
    type=str,
    default=".",
    help="visualization folder during training",
)
# -----------------------------------------------------------------------------
# Visualization
# Options affecting console/progress display.
vis_arg = add_argument_group("Visualization")
vis_arg.add_argument(
    "--tqdm_width",
    type=int,
    default=79,
    help="width of the tqdm bar",
)
def get_config():
    """Parse the process command line with the module parser.

    Returns:
        A ``(config, unparsed)`` pair as produced by
        ``parser.parse_known_args()``: the namespace of recognised options
        and the list of arguments that were not recognised.
    """
    return parser.parse_known_args()
def print_usage():
    """Print the module parser's brief usage message."""
    parser.print_usage()
#
# config.py ends here