File size: 3,870 Bytes
10b4a5f
 
358ab8f
10b4a5f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
358ab8f
 
10b4a5f
358ab8f
 
 
 
 
10b4a5f
 
 
 
 
358ab8f
 
10b4a5f
358ab8f
 
10b4a5f
358ab8f
 
10b4a5f
358ab8f
 
10b4a5f
358ab8f
 
 
 
 
10b4a5f
358ab8f
 
 
 
 
10b4a5f
358ab8f
 
10b4a5f
 
 
 
 
358ab8f
10b4a5f
358ab8f
 
 
 
 
10b4a5f
358ab8f
 
10b4a5f
358ab8f
 
 
 
 
10b4a5f
 
 
 
 
358ab8f
 
10b4a5f
358ab8f
 
10b4a5f
358ab8f
 
 
10b4a5f
358ab8f
 
10b4a5f
358ab8f
 
 
10b4a5f
358ab8f
 
10b4a5f
358ab8f
 
10b4a5f
358ab8f
 
10b4a5f
358ab8f
 
 
 
 
10b4a5f
 
 
358ab8f
10b4a5f
358ab8f
10b4a5f
 
358ab8f
10b4a5f
 
 
 
 
 
 
 
358ab8f
10b4a5f
358ab8f
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
import argparse


def str2bool(v):
    """Convert a command-line string into a ``bool``.

    Intended as the ``type=`` converter for boolean flags such as
    ``--data_aug``. Matching is case-insensitive and ignores surrounding
    whitespace.

    Args:
        v: Raw argument value. A value that is already a ``bool`` is returned
            unchanged; any other value is converted with ``str()`` first.

    Returns:
        ``True`` for "true", "1", "yes", "y", "t", "on";
        ``False`` for "false", "0", "no", "n", "f", "off".

    Raises:
        argparse.ArgumentTypeError: If ``v`` is not a recognised boolean
            spelling. argparse reports this as a normal usage error, instead
            of silently turning a typo like "ture" into ``False``.
    """
    if isinstance(v, bool):
        return v
    token = str(v).strip().lower()
    if token in ("true", "1", "yes", "y", "t", "on"):
        return True
    if token in ("false", "0", "no", "n", "f", "off"):
        return False
    raise argparse.ArgumentTypeError(
        "Boolean value expected (true/false, 1/0, yes/no), got %r" % (v,)
    )


# Every argument group created via add_argument_group(), in creation order.
arg_lists = list()
# The single module-wide parser that all option groups below attach to.
parser = argparse.ArgumentParser()


def add_argument_group(name):
    """Create a titled option group on the module parser and record it.

    Args:
        name: Title shown for the group in ``--help`` output.

    Returns:
        The newly created argparse argument group; options are added to it
        with ``.add_argument(...)``. The group is also appended to
        ``arg_lists``.
    """
    group = parser.add_argument_group(name)
    arg_lists.append(group)
    return group


# -----------------------------------------------------------------------------
# Network
net_arg = add_argument_group("Network")
net_arg.add_argument(
    "--model_name",
    type=str,
    default="SGM",
    help="model for training",
)
net_arg.add_argument(
    "--config_path",
    type=str,
    default="configs/sgm.yaml",
    help="config path for model",
)

# -----------------------------------------------------------------------------
# Data
data_arg = add_argument_group("Data")
# Input/output locations.
data_arg.add_argument(
    "--rawdata_path",
    type=str,
    default="rawdata",
    help="path for rawdata",
)
data_arg.add_argument(
    "--dataset_path",
    type=str,
    default="dataset",
    help="path for dataset",
)
data_arg.add_argument(
    "--desc_path",
    type=str,
    default="desc",
    help="path for descriptor(kpt) dir",
)
# Keypoint sampling and preprocessing.
data_arg.add_argument(
    "--num_kpt",
    type=int,
    default=1000,
    help="number of kpt for training",
)
data_arg.add_argument(
    "--input_normalize",
    type=str,
    default="img",
    help="normalize type for input kpt, img or intrinsic",
)
data_arg.add_argument(
    "--data_aug",
    type=str2bool,
    default=True,
    help="apply kpt coordinate homography augmentation",
)
data_arg.add_argument(
    "--desc_suffix",
    type=str,
    default="suffix",
    help="desc file suffix",
)


# -----------------------------------------------------------------------------
# Loss
# Title-cased to match the other option groups ("Network", "Data", "Train").
loss_arg = add_argument_group("Loss")
loss_arg.add_argument("--momentum", type=float, default=0.9, help="momentum")
# Float literals for the defaults below: argparse only runs `type` on string
# defaults, so an int default would make the config value an int when the
# flag is omitted but a float when it is supplied on the command line.
loss_arg.add_argument(
    "--seed_loss_weight",
    type=float,
    default=250.0,
    help="confidence loss weight for sgm",
)
loss_arg.add_argument(
    "--mid_loss_weight",
    type=float,
    default=1.0,
    help="midseeding loss weight for sgm",
)
loss_arg.add_argument(
    "--inlier_th",
    type=float,
    default=5e-3,
    help="inlier threshold for epipolar distance (for sgm and visualization)",
)


# -----------------------------------------------------------------------------
# Training
train_arg = add_argument_group("Train")
train_arg.add_argument("--train_lr", type=float, default=1e-4, help="learning rate")
train_arg.add_argument("--train_batch_size", type=int, default=16, help="batch size")
train_arg.add_argument(
    "--gpu_id", type=str, default="0", help="id(s) for CUDA_VISIBLE_DEVICES"
)
train_arg.add_argument(
    "--train_iter", type=int, default=1000000, help="training iterations to perform"
)
train_arg.add_argument("--log_base", type=str, default="./log/", help="log path")
train_arg.add_argument(
    "--val_intv", type=int, default=20000, help="validation interval"
)
# NOTE(review): the flag is named save_intv but its help text says "summary
# interval" -- confirm against the training loop which one it controls.
train_arg.add_argument(
    "--save_intv", type=int, default=1000, help="summary interval"
)
train_arg.add_argument("--log_intv", type=int, default=100, help="log interval")
train_arg.add_argument(
    "--decay_rate", type=float, default=0.999996, help="lr decay rate"
)
train_arg.add_argument(
    "--decay_iter",
    type=float,
    # Float literal so the default has the same type as a value parsed from
    # the command line (argparse does not apply `type` to non-string defaults).
    default=300000.0,
    help="lr decay iter",
)
train_arg.add_argument(
    "--local_rank", type=int, default=0, help="local rank for ddp"
)
train_arg.add_argument(
    "--train_vis_folder",
    type=str,
    default=".",
    help="visualization folder during training",
)

# -----------------------------------------------------------------------------
# Visualization
vis_arg = add_argument_group("Visualization")
vis_arg.add_argument(
    "--tqdm_width",
    type=int,
    default=79,
    help="width of the tqdm bar",
)


def get_config(args=None):
    """Parse command-line options into a config namespace.

    Unrecognised options do not cause an error. They are returned alongside
    the parsed config so the caller can decide how to handle them.

    Args:
        args: Optional list of argument strings to parse. When ``None`` (the
            default), argparse reads ``sys.argv[1:]``, exactly as before.
            Passing a list allows building a config programmatically, e.g. in
            tests or notebooks.

    Returns:
        A ``(config, unparsed)`` tuple: the ``argparse.Namespace`` of known
        options and the list of leftover argument strings.
    """
    config, unparsed = parser.parse_known_args(args)
    return config, unparsed


def print_usage():
    """Print the module parser's brief usage line.

    argparse writes it to ``sys.stdout`` when no file is given.
    """
    show_usage = parser.print_usage
    show_usage()


#
# config.py ends here