# Command-line configuration for training.
#
# Options are registered on a single module-level ``argparse`` parser, grouped
# by topic (network, data, loss, training, visualization). Call
# ``get_config()`` to parse ``sys.argv``.
import argparse


def str2bool(v):
    """Convert a command-line string to a bool.

    Args:
        v: The raw value. Strings are compared case-insensitively; a value
            that is already a ``bool`` is returned unchanged.

    Returns:
        True if ``v`` is ``"true"`` or ``"1"`` (any letter case), or is the
        bool ``True``; False for every other string.
    """
    # Programmatic callers may hand over a real bool, which has no .lower().
    if isinstance(v, bool):
        return v
    return v.lower() in ("true", "1")


# Every argument group created through add_argument_group(), in creation order.
arg_lists = []
parser = argparse.ArgumentParser()


def add_argument_group(name):
    """Create a named argument group on the shared parser and record it.

    Args:
        name: Title of the group (shown as a section header in ``--help``).

    Returns:
        The new argument group; it is also appended to ``arg_lists``.
    """
    group = parser.add_argument_group(name)
    arg_lists.append(group)
    return group


# -----------------------------------------------------------------------------
# Network
net_arg = add_argument_group("Network")
net_arg.add_argument(
    "--model_name", type=str, default='SGM',
    help="model for training")
net_arg.add_argument(
    "--config_path", type=str, default='configs/sgm.yaml',
    help="config path for model")

# -----------------------------------------------------------------------------
# Data
data_arg = add_argument_group("Data")
data_arg.add_argument(
    "--rawdata_path", type=str, default='rawdata',
    help="path for rawdata")
data_arg.add_argument(
    "--dataset_path", type=str, default='dataset',
    help="path for dataset")
data_arg.add_argument(
    "--desc_path", type=str, default='desc',
    help="path for descriptor(kpt) dir")
data_arg.add_argument(
    "--num_kpt", type=int, default=1000,
    help="number of kpt for training")
data_arg.add_argument(
    "--input_normalize", type=str, default='img',
    help="normalize type for input kpt, img or intrinsic")
data_arg.add_argument(
    "--data_aug", type=str2bool, default=True,
    help="apply kpt coordinate homography augmentation")
data_arg.add_argument(
    "--desc_suffix", type=str, default='suffix',
    help="desc file suffix")

# -----------------------------------------------------------------------------
# Loss
loss_arg = add_argument_group("loss")
loss_arg.add_argument(
    "--momentum", type=float, default=0.9,
    help="momentum")
loss_arg.add_argument(
    "--seed_loss_weight", type=float, default=250,
    help="confidence loss weight for sgm")
loss_arg.add_argument(
    "--mid_loss_weight", type=float, default=1,
    help="midseeding loss weight for sgm")
loss_arg.add_argument(
    "--inlier_th", type=float, default=5e-3,
    help="inlier threshold for epipolar distance (for sgm and visualization)")

# -----------------------------------------------------------------------------
# Training
train_arg = add_argument_group("Train")
train_arg.add_argument(
    "--train_lr", type=float, default=1e-4,
    help="learning rate")
train_arg.add_argument(
    "--train_batch_size", type=int, default=16,
    help="batch size")
train_arg.add_argument(
    "--gpu_id", type=str, default='0',
    help='id(s) for CUDA_VISIBLE_DEVICES')
train_arg.add_argument(
    "--train_iter", type=int, default=1000000,
    help="training iterations to perform")
train_arg.add_argument(
    "--log_base", type=str, default="./log/",
    help="log path")
train_arg.add_argument(
    "--val_intv", type=int, default=20000,
    help="validation interval")
# NOTE(review): the option is named "save" but its help text says "summary";
# the text is kept as-is -- confirm which one is intended.
train_arg.add_argument(
    "--save_intv", type=int, default=1000,
    help="summary interval")
train_arg.add_argument(
    "--log_intv", type=int, default=100,
    help="log interval")
train_arg.add_argument(
    "--decay_rate", type=float, default=0.999996,
    help="lr decay rate")
# Parsed as float (so values such as "3e5" are accepted on the command line);
# the untouched default stays the int 300000.
train_arg.add_argument(
    "--decay_iter", type=float, default=300000,
    help="lr decay iter")
train_arg.add_argument(
    "--local_rank", type=int, default=0,
    help="local rank for ddp")
train_arg.add_argument(
    "--train_vis_folder", type=str, default='.',
    help="visualization folder during training")

# -----------------------------------------------------------------------------
# Visualization
vis_arg = add_argument_group('Visualization')
vis_arg.add_argument(
    "--tqdm_width", type=int, default=79,
    help="width of the tqdm bar")


def get_config():
    """Parse the process command line with the shared parser.

    Returns:
        A ``(config, unparsed)`` tuple: the namespace of recognized options
        and the list of arguments the parser did not recognize.
    """
    config, unparsed = parser.parse_known_args()
    return config, unparsed


def print_usage():
    """Print the parser's short usage message."""
    parser.print_usage()

#
# config.py ends here