File size: 3,519 Bytes
437b5f6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import argparse

arg_lists = []
parser = argparse.ArgumentParser(description='LANet')

def str2bool(v):
    return v.lower() in ('true', '1')

def add_argument_group(name):
    arg = parser.add_argument_group(name)
    arg_lists.append(arg)
    return arg

# train data params
traindata_arg = add_argument_group('Traindata Params')
traindata_arg.add_argument('--train_txt', type=str, default='',
                            help='Train set.')
traindata_arg.add_argument('--train_root', type=str, default='',
                            help='Where the train images are.')
traindata_arg.add_argument('--batch_size', type=int, default=8,
                            help='# of images in each batch of data')
traindata_arg.add_argument('--num_workers', type=int, default=4,
                            help='# of subprocesses to use for data loading')
traindata_arg.add_argument('--pin_memory', type=str2bool, default=True,
                            help='# of subprocesses to use for data loading')
traindata_arg.add_argument('--shuffle', type=str2bool, default=True,
                            help='Whether to shuffle the train and valid indices')
traindata_arg.add_argument('--image_shape', type=tuple, default=(240, 320),
                            help='')
traindata_arg.add_argument('--jittering', type=tuple, default=(0.5, 0.5, 0.2, 0.05),
                            help='')

# data storage
storage_arg = add_argument_group('Storage')
storage_arg.add_argument('--ckpt_name', type=str, default='PointModel',
                            help='')

# training params
train_arg = add_argument_group('Training Params')
train_arg.add_argument('--start_epoch', type=int, default=0,
                        help='')
train_arg.add_argument('--max_epoch', type=int, default=12,
                        help='')
train_arg.add_argument('--init_lr', type=float, default=3e-4,
                        help='Initial learning rate value.')
train_arg.add_argument('--lr_factor', type=float, default=0.5,
                        help='Reduce learning rate value.')	
train_arg.add_argument('--momentum', type=float, default=0.9,
                        help='Nesterov momentum value.')			   
train_arg.add_argument('--display', type=int, default=50,
                        help='')

# loss function params
loss_arg = add_argument_group('Loss function Params')
loss_arg.add_argument('--score_weight', type=float, default=1.,
                        help='')
loss_arg.add_argument('--loc_weight', type=float, default=1.,
                        help='')
loss_arg.add_argument('--desc_weight', type=float, default=4.,
                        help='')
loss_arg.add_argument('--corres_weight', type=float, default=.5,
                        help='')
loss_arg.add_argument('--corres_threshold', type=int, default=4.,
                        help='')
					   
# other params
misc_arg = add_argument_group('Misc.')
misc_arg.add_argument('--use_gpu', type=str2bool, default=True,
                        help="Whether to run on the GPU.")
misc_arg.add_argument('--gpu', type=int, default=0,
                        help="Which GPU to run on.")										  
misc_arg.add_argument('--seed', type=int, default=1001,
                        help='Seed to ensure reproducibility.')					  
misc_arg.add_argument('--ckpt_dir', type=str, default='./checkpoints',
                        help='Directory in which to save model checkpoints.')					  

def get_config():
    config, unparsed = parser.parse_known_args()
    return config, unparsed