File size: 3,890 Bytes
a80d6bb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
import argparse

def str2bool(v):
    """Convert a command-line string into a bool, for use as argparse ``type=``.

    Matching is case-insensitive and ignores surrounding whitespace:
    "true", "t", "yes", "y", "on", "1" -> True;
    "false", "f", "no", "n", "off", "0" -> False.
    A value that is already a ``bool`` is returned unchanged.

    Args:
        v: The raw option value.

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: if ``v`` is not a recognised spelling.
            argparse turns this into a clean usage error, instead of silently
            treating a typo such as "flase" as False.
    """
    if isinstance(v, bool):
        return v
    normalized = str(v).strip().lower()
    if normalized in ("true", "t", "yes", "y", "on", "1"):
        return True
    if normalized in ("false", "f", "no", "n", "off", "0"):
        return False
    raise argparse.ArgumentTypeError(
        "boolean value expected, got {!r}".format(v))


# Every group created through add_argument_group() is appended here, in
# creation order.
arg_lists = []
# Single module-wide parser that all option groups below attach to.
parser = argparse.ArgumentParser()


def add_argument_group(name):
    """Create a titled argument group on the module ``parser`` and record it.

    Args:
        name: Title of the group, shown as a section heading in ``--help``.

    Returns:
        The new argparse argument group, ready for ``add_argument`` calls.
    """
    new_group = parser.add_argument_group(name)
    arg_lists.append(new_group)
    return new_group


# -----------------------------------------------------------------------------
# Network: which model to build and where its config file lives.
net_arg = add_argument_group("Network")
net_arg.add_argument("--model_name", type=str, default="SGM",
                     help="model for training")
net_arg.add_argument("--config_path", type=str, default="configs/sgm.yaml",
                     help="config path for model")

# -----------------------------------------------------------------------------
# Data: input locations and keypoint/descriptor loading options.
data_arg = add_argument_group("Data")
data_arg.add_argument(
    "--rawdata_path", type=str, default='rawdata',
    help="path for rawdata")
data_arg.add_argument(
    "--dataset_path", type=str, default='dataset',
    help="path for dataset")
data_arg.add_argument(
    "--desc_path", type=str, default='desc',
    help="path for descriptor(kpt) dir")
data_arg.add_argument(
    "--num_kpt", type=int, default=1000,
    help="number of kpt for training")
# The help text documents exactly two modes; enforce them at parse time so a
# misspelled value fails here rather than somewhere downstream.
data_arg.add_argument(
    "--input_normalize", type=str, default='img',
    choices=('img', 'intrinsic'),
    help="normalize type for input kpt, img or intrinsic")
data_arg.add_argument(
    "--data_aug", type=str2bool, default=True,
    help="apply kpt coordinate homography augmentation")
data_arg.add_argument(
    "--desc_suffix", type=str, default='suffix',
    help="desc file suffix")


# -----------------------------------------------------------------------------
# Loss: loss weights and the epipolar inlier threshold.
loss_arg = add_argument_group("loss")
loss_arg.add_argument("--momentum", type=float, default=0.9,
                      help="momentum")
loss_arg.add_argument("--seed_loss_weight", type=float, default=250,
                      help="confidence loss weight for sgm")
loss_arg.add_argument("--mid_loss_weight", type=float, default=1,
                      help="midseeding loss weight for sgm")
loss_arg.add_argument("--inlier_th", type=float, default=5e-3,
                      help="inlier threshold for epipolar distance "
                           "(for sgm and visualization)")


# -----------------------------------------------------------------------------
# Training: optimiser schedule, iteration counts, logging and device options.
train_arg = add_argument_group("Train")
train_arg.add_argument(
    "--train_lr", type=float, default=1e-4,
    help="learning rate")
train_arg.add_argument(
    "--train_batch_size", type=int, default=16,
    help="batch size")
train_arg.add_argument(
    "--gpu_id", type=str, default='0',
    help='id(s) for CUDA_VISIBLE_DEVICES')
train_arg.add_argument(
    "--train_iter", type=int, default=1000000,
    help="training iterations to perform")
train_arg.add_argument(
    "--log_base", type=str, default="./log/",
    help="log path")
train_arg.add_argument(
    "--val_intv", type=int, default=20000,
    help="validation interval")
train_arg.add_argument(
    "--save_intv", type=int, default=1000,
    help="summary interval")
train_arg.add_argument(
    "--log_intv", type=int, default=100,
    help="log interval")
train_arg.add_argument(
    "--decay_rate", type=float, default=0.999996,
    help="lr decay rate")
# An iteration count: parse as int so a CLI-supplied value has the same type
# as the integer default. With type=float, "300000" became 300000.0.
train_arg.add_argument(
    "--decay_iter", type=int, default=300000,
    help="lr decay iter")
train_arg.add_argument(
    "--local_rank", type=int, default=0,
    help="local rank for ddp")
train_arg.add_argument(
    "--train_vis_folder", type=str, default='.',
    help="visualization folder during training")

# -----------------------------------------------------------------------------
# Visualization: console/progress-bar display options.
vis_arg = add_argument_group("Visualization")
vis_arg.add_argument("--tqdm_width", type=int, default=79,
                     help="width of the tqdm bar")

def get_config():
    """Parse the process command line with the module ``parser``.

    Unknown flags do not raise an error; they are handed back to the caller.

    Returns:
        A ``(config, unparsed)`` tuple: the ``argparse.Namespace`` of
        recognised options and the list of leftover argument strings.
    """
    namespace, leftovers = parser.parse_known_args()
    return namespace, leftovers


def print_usage():
    """Print the module parser's short usage message to standard output."""
    # file=None is argparse's default and means sys.stdout.
    parser.print_usage(file=None)

#
# config.py ends here