import argparse


def str2bool(v):
    """Convert a command-line style value to a bool.

    Intended for use as an argparse ``type=`` callable, but also safe to call
    directly with values that are not strings.

    Args:
        v: The value to interpret. Bools are returned unchanged; anything else
            is converted with ``str()`` before being compared.

    Returns:
        True when the text is "true" or "1" (case-insensitive, surrounding
        whitespace ignored); False for every other value.
    """
    # A real bool has no .lower(); pass it through instead of raising.
    if isinstance(v, bool):
        return v
    return str(v).strip().lower() in ("true", "1")


# Single module-level parser; every option in this file is registered on it.
parser = argparse.ArgumentParser()
# Argument groups created via add_argument_group(), kept in creation order.
arg_lists = []


def add_argument_group(name):
    """Create a titled argument group on the shared parser and record it.

    Args:
        name: Title of the group, forwarded to ``parser.add_argument_group``.

    Returns:
        The newly created argument group, which has also been appended to the
        module-level ``arg_lists``.
    """
    group = parser.add_argument_group(name)
    arg_lists.append(group)
    return group


# -----------------------------------------------------------------------------
# Network
# Which model to train and where its config file lives.
net_arg = add_argument_group("Network")
net_arg.add_argument(
    "--model_name",
    type=str,
    default="SGM",
    help="model for training",
)
net_arg.add_argument(
    "--config_path",
    type=str,
    default="configs/sgm.yaml",
    help="config path for model",
)

# -----------------------------------------------------------------------------
# Data
# Input locations and keypoint pre-processing options.
data_arg = add_argument_group("Data")
data_arg.add_argument(
    "--rawdata_path",
    type=str,
    default="rawdata",
    help="path for rawdata",
)
data_arg.add_argument(
    "--dataset_path",
    type=str,
    default="dataset",
    help="path for dataset",
)
data_arg.add_argument(
    "--desc_path",
    type=str,
    default="desc",
    help="path for descriptor(kpt) dir",
)
data_arg.add_argument(
    "--num_kpt",
    type=int,
    default=1000,
    help="number of kpt for training",
)
data_arg.add_argument(
    "--input_normalize",
    type=str,
    default="img",
    help="normalize type for input kpt, img or intrinsic",
)
# str2bool turns the command-line text into a bool; the default is used as-is.
data_arg.add_argument(
    "--data_aug",
    type=str2bool,
    default=True,
    help="apply kpt coordinate homography augmentation",
)
data_arg.add_argument(
    "--desc_suffix",
    type=str,
    default="suffix",
    help="desc file suffix",
)


# -----------------------------------------------------------------------------
# Loss
# Loss weighting options.
# NOTE: argparse applies ``type`` only to strings, so the integer defaults
# below (250, 1) stay ints unless the flag is given on the command line.
loss_arg = add_argument_group("loss")
loss_arg.add_argument(
    "--momentum",
    type=float,
    default=0.9,
    help="momentum",
)
loss_arg.add_argument(
    "--seed_loss_weight",
    type=float,
    default=250,
    help="confidence loss weight for sgm",
)
loss_arg.add_argument(
    "--mid_loss_weight",
    type=float,
    default=1,
    help="midseeding loss weight for sgm",
)
loss_arg.add_argument(
    "--inlier_th",
    type=float,
    default=5e-3,
    help="inlier threshold for epipolar distance (for sgm and visualization)",
)


# -----------------------------------------------------------------------------
# Training
# Optimisation schedule, logging cadence and device selection.
train_arg = add_argument_group("Train")
train_arg.add_argument(
    "--train_lr",
    type=float,
    default=1e-4,
    help="learning rate",
)
train_arg.add_argument(
    "--train_batch_size",
    type=int,
    default=16,
    help="batch size",
)
train_arg.add_argument(
    "--gpu_id",
    type=str,
    default="0",
    help="id(s) for CUDA_VISIBLE_DEVICES",
)
train_arg.add_argument(
    "--train_iter",
    type=int,
    default=1000000,
    help="training iterations to perform",
)
train_arg.add_argument(
    "--log_base",
    type=str,
    default="./log/",
    help="log path",
)
train_arg.add_argument(
    "--val_intv",
    type=int,
    default=20000,
    help="validation interval",
)
# NOTE(review): the flag is named save_intv but the help text says "summary
# interval" -- confirm which meaning the training loop actually uses.
train_arg.add_argument(
    "--save_intv",
    type=int,
    default=1000,
    help="summary interval",
)
train_arg.add_argument(
    "--log_intv",
    type=int,
    default=100,
    help="log interval",
)
train_arg.add_argument(
    "--decay_rate",
    type=float,
    default=0.999996,
    help="lr decay rate",
)
# NOTE(review): parsed as float although the default is an integer and the
# help text describes an iteration -- confirm whether int was intended.
train_arg.add_argument(
    "--decay_iter",
    type=float,
    default=300000,
    help="lr decay iter",
)
train_arg.add_argument(
    "--local_rank",
    type=int,
    default=0,
    help="local rank for ddp",
)
train_arg.add_argument(
    "--train_vis_folder",
    type=str,
    default=".",
    help="visualization folder during training",
)

# -----------------------------------------------------------------------------
# Visualization
# Console display options.
vis_arg = add_argument_group("Visualization")
vis_arg.add_argument(
    "--tqdm_width",
    type=int,
    default=79,
    help="width of the tqdm bar",
)


def get_config(argv=None):
    """Parse options with the module-level parser, tolerating unknown ones.

    Args:
        argv: Optional list of argument strings to parse. When None (the
            default) argparse reads ``sys.argv[1:]``, which matches the
            previous zero-argument behaviour. Passing an explicit list lets
            tests and interactive sessions parse without touching
            ``sys.argv``.

    Returns:
        A ``(config, unparsed)`` tuple: the populated ``argparse.Namespace``
        and the list of argument strings the parser did not recognise.
    """
    # parse_known_args (not parse_args) so unrecognised flags are returned to
    # the caller instead of aborting the program.
    return parser.parse_known_args(argv)


def print_usage():
    """Print the module-level parser's brief usage message.

    Delegates to ``parser.print_usage()`` with no ``file`` argument, so
    argparse writes to its default stream (stdout).
    """
    parser.print_usage()


#
# config.py ends here