File size: 3,870 Bytes
404d2af 8b973ee 404d2af 8b973ee 404d2af 8b973ee 404d2af 8b973ee 404d2af 8b973ee 404d2af 8b973ee 404d2af 8b973ee 404d2af 8b973ee 404d2af 8b973ee 404d2af 8b973ee 404d2af 8b973ee 404d2af 8b973ee 404d2af 8b973ee 404d2af 8b973ee 404d2af 8b973ee 404d2af 8b973ee 404d2af 8b973ee 404d2af 8b973ee 404d2af 8b973ee 404d2af 8b973ee 404d2af 8b973ee 404d2af 8b973ee 404d2af 8b973ee 404d2af 8b973ee 404d2af 8b973ee 404d2af 8b973ee 404d2af 8b973ee 404d2af 8b973ee |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 |
import argparse
def str2bool(v):
    """Convert a command-line string to a bool, for use as an argparse ``type=``.

    Case-insensitive and ignores surrounding whitespace. ``true``, ``1``,
    ``yes``, ``y``, ``t`` and ``on`` map to True. ``false``, ``0``, ``no``,
    ``n``, ``f`` and ``off`` map to False. A value that is already a ``bool``
    is returned unchanged.

    Args:
        v: The raw option value (normally a ``str``; a ``bool`` is also accepted).

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If ``v`` is not a recognised boolean
            spelling. argparse turns this into a normal usage error instead of
            silently treating typos such as ``"ture"`` as False.
    """
    if isinstance(v, bool):
        return v
    text = str(v).strip().lower()
    if text in ("true", "1", "yes", "y", "t", "on"):
        return True
    if text in ("false", "0", "no", "n", "f", "off"):
        return False
    raise argparse.ArgumentTypeError(
        "boolean value expected (true/false, 1/0, yes/no, on/off), got %r" % (v,)
    )
# Every group created through add_argument_group(), in creation order.
arg_lists = []
# The single parser that all option groups below are attached to.
parser = argparse.ArgumentParser()


def add_argument_group(name):
    """Create a titled option group on the module-level ``parser`` and record it.

    Args:
        name: Title of the group, shown as a section heading in ``--help``.

    Returns:
        The newly created argparse argument group. Options added to it are
        parsed by ``parser``.
    """
    new_group = parser.add_argument_group(name)
    arg_lists.append(new_group)
    return new_group
# -----------------------------------------------------------------------------
# Network
# Options choosing the model and the file holding its configuration.
net_arg = add_argument_group("Network")
net_arg.add_argument(
    "--model_name",
    type=str,
    default="SGM",
    help="model for training",
)
net_arg.add_argument(
    "--config_path",
    type=str,
    default="configs/sgm.yaml",
    help="config path for model",
)
# -----------------------------------------------------------------------------
# Data
# Options describing where the data lives and how keypoint input is prepared.
data_arg = add_argument_group("Data")
data_arg.add_argument(
    "--rawdata_path",
    type=str,
    default="rawdata",
    help="path for rawdata",
)
data_arg.add_argument(
    "--dataset_path",
    type=str,
    default="dataset",
    help="path for dataset",
)
data_arg.add_argument(
    "--desc_path",
    type=str,
    default="desc",
    help="path for descriptor(kpt) dir",
)
data_arg.add_argument(
    "--num_kpt",
    type=int,
    default=1000,
    help="number of kpt for training",
)
data_arg.add_argument(
    "--input_normalize",
    type=str,
    default="img",
    help="normalize type for input kpt, img or intrinsic",
)
# Parsed with str2bool. The non-string default True is used as-is, because
# argparse only runs ``type`` on string defaults.
data_arg.add_argument(
    "--data_aug",
    type=str2bool,
    default=True,
    help="apply kpt coordinate homography augmentation",
)
data_arg.add_argument(
    "--desc_suffix",
    type=str,
    default="suffix",
    help="desc file suffix",
)
# -----------------------------------------------------------------------------
# Loss
# Loss weighting and the epipolar inlier threshold (momentum is also kept here).
loss_arg = add_argument_group("loss")
loss_arg.add_argument(
    "--momentum",
    type=float,
    default=0.9,
    help="momentum",
)
loss_arg.add_argument(
    "--seed_loss_weight",
    type=float,
    default=250,
    help="confidence loss weight for sgm",
)
loss_arg.add_argument(
    "--mid_loss_weight",
    type=float,
    default=1,
    help="midseeding loss weight for sgm",
)
loss_arg.add_argument(
    "--inlier_th",
    type=float,
    default=5e-3,
    help="inlier threshold for epipolar distance (for sgm and visualization)",
)
# -----------------------------------------------------------------------------
# Training
# Optimisation schedule, devices, logging cadence and output locations.
train_arg = add_argument_group("Train")
train_arg.add_argument(
    "--train_lr",
    type=float,
    default=1e-4,
    help="learning rate",
)
train_arg.add_argument(
    "--train_batch_size",
    type=int,
    default=16,
    help="batch size",
)
train_arg.add_argument(
    "--gpu_id",
    type=str,
    default="0",
    help="id(s) for CUDA_VISIBLE_DEVICES",
)
train_arg.add_argument(
    "--train_iter",
    type=int,
    default=1000000,
    help="training iterations to perform",
)
train_arg.add_argument(
    "--log_base",
    type=str,
    default="./log/",
    help="log path",
)
train_arg.add_argument(
    "--val_intv",
    type=int,
    default=20000,
    help="validation interval",
)
# NOTE(review): the help text describes --save_intv as a "summary interval";
# confirm the intended wording.
train_arg.add_argument(
    "--save_intv",
    type=int,
    default=1000,
    help="summary interval",
)
train_arg.add_argument(
    "--log_intv",
    type=int,
    default=100,
    help="log interval",
)
train_arg.add_argument(
    "--decay_rate",
    type=float,
    default=0.999996,
    help="lr decay rate",
)
# NOTE(review): --decay_iter looks like an iteration count but is parsed as
# float, while its default (300000) stays an int -- confirm whether type=int
# was intended.
train_arg.add_argument(
    "--decay_iter",
    type=float,
    default=300000,
    help="lr decay iter",
)
train_arg.add_argument(
    "--local_rank",
    type=int,
    default=0,
    help="local rank for ddp",
)
train_arg.add_argument(
    "--train_vis_folder",
    type=str,
    default=".",
    help="visualization folder during training",
)
# -----------------------------------------------------------------------------
# Visualization
# Options controlling console progress display.
vis_arg = add_argument_group("Visualization")
vis_arg.add_argument(
    "--tqdm_width",
    type=int,
    default=79,
    help="width of the tqdm bar",
)
def get_config():
    """Parse the process command line against the module-level ``parser``.

    Unrecognised options are not an error; they are returned separately.

    Returns:
        A ``(config, unparsed)`` tuple. ``config`` is the ``argparse.Namespace``
        of known options; ``unparsed`` is the list of leftover command-line
        strings.
    """
    return parser.parse_known_args()
def print_usage():
    """Print the module-level ``parser``'s brief usage message to stdout."""
    parser.print_usage(file=None)  # file=None selects sys.stdout
#
# config.py ends here
|