File size: 5,723 Bytes
aeca520
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
import pytorch_lightning as pl
import argparse
import pprint
from loguru import logger as loguru_logger

from src.config.default import get_cfg_defaults
from src.utils.profiler import build_profiler

from src.lightning.data import MultiSceneDataModule
from src.lightning.lightning_loftr import PL_LoFTR

import torch

def parse_args():
    """Parse the command line of the test script.

    Two positional config paths and a set of optional evaluation overrides
    are registered first; the parser is then handed to
    ``pl.Trainer.add_argparse_args`` so that the Trainer flags are accepted
    on the same command line.

    Returns:
        The namespace produced by ``parser.parse_args()``.
    """
    # init a custom parser which will be added into pl.Trainer parser
    # check documentation: https://pytorch-lightning.readthedocs.io/en/latest/common/trainer.html#trainer-flags
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    # (name, add_argument keyword arguments), registered in this exact order.
    option_table = [
        ('data_cfg_path', dict(type=str, help='data config path')),
        ('main_cfg_path', dict(type=str, help='main config path')),
        ('--ckpt_path', dict(type=str, default="weights/indoor_ds.ckpt", help='path to the checkpoint')),
        ('--dump_dir', dict(type=str, default=None, help="if set, the matching results will be dump to dump_dir")),
        ('--profiler_name', dict(type=str, default=None, help='options: [inference, pytorch], or leave it unset')),
        ('--batch_size', dict(type=int, default=1, help='batch_size per gpu')),
        ('--num_workers', dict(type=int, default=2)),
        ('--thr', dict(type=float, default=None, help='modify the coarse-level matching threshold.')),
        ('--pixel_thr', dict(type=float, default=None, help='modify the RANSAC threshold.')),
        ('--ransac', dict(type=str, default=None, help='modify the RANSAC method')),
        ('--scannetX', dict(type=int, default=None, help='ScanNet resize X')),
        ('--scannetY', dict(type=int, default=None, help='ScanNet resize Y')),
        ('--megasize', dict(type=int, default=None, help='MegaDepth resize')),
        ('--npe', dict(action='store_true', default=False, help='')),
        ('--fp32', dict(action='store_true', default=False, help='')),
        ('--ransac_times', dict(type=int, default=None, help='repeat ransac multiple times for more robust evaluation')),
        ('--rmbd', dict(type=int, default=None, help='remove border matches')),
        ('--deter', dict(action='store_true', default=False, help='use deterministic mode for testing')),
        ('--half', dict(action='store_true', default=False, help='pure16')),
        ('--flash', dict(action='store_true', default=False, help='flash')),
    ]
    for option_name, option_kwargs in option_table:
        parser.add_argument(option_name, **option_kwargs)

    # Let Lightning append its own Trainer flags, then parse.
    full_parser = pl.Trainer.add_argparse_args(parser)
    return full_parser.parse_args()

def inplace_relu(m):
    """Set ``m.inplace = True`` when the class name of ``m`` contains 'ReLU'.

    Objects whose class name does not contain the substring are left
    untouched. Nothing is returned.
    """
    if 'ReLU' in m.__class__.__name__:
        m.inplace = True

if __name__ == '__main__':
    # Script entry point: parse the CLI, build the config, apply the
    # command-line overrides to it, then run `trainer.test`.

    # parse arguments
    args = parse_args()
    pprint.pprint(vars(args))

    # init default-cfg and merge it with the main- and data-cfg
    config = get_cfg_defaults()
    config.merge_from_file(args.main_cfg_path)
    config.merge_from_file(args.data_cfg_path)
    if args.deter:
        torch.backends.cudnn.deterministic = True
    pl.seed_everything(config.TRAINER.SEED)  # reproducibility

    # tune when testing
    if args.thr is not None:
        config.LOFTR.MATCH_COARSE.THR = args.thr

    # The ScanNet resize is only overridden when both dimensions are given.
    if args.scannetX is not None and args.scannetY is not None:
        config.DATASET.SCAN_IMG_RESIZEX = args.scannetX
        config.DATASET.SCAN_IMG_RESIZEY = args.scannetY
    elif args.scannetX is not None or args.scannetY is not None:
        # Previously a half-specified override was dropped without notice.
        loguru_logger.warning(
            "--scannetX and --scannetY must be given together; "
            "ignoring the ScanNet resize override.")
    if args.megasize is not None:
        config.DATASET.MGDPT_IMG_RESIZE = args.megasize

    if args.npe:
        # Explicit raise instead of `assert`, which is removed under `python -O`.
        if config.LOFTR.COARSE.ROPE and config.DATASET.NPE_NAME is None:
            raise ValueError(
                "config.DATASET.NPE_NAME must be set when --npe is used "
                "with config.LOFTR.COARSE.ROPE enabled")
        # Any other NPE_NAME (including None) leaves config.LOFTR.COARSE.NPE as loaded.
        npe_name = config.DATASET.NPE_NAME
        if npe_name == 'megadepth':
            config.LOFTR.COARSE.NPE = [832, 832, config.DATASET.MGDPT_IMG_RESIZE, config.DATASET.MGDPT_IMG_RESIZE] # [832, 832, 1152, 1152]
        elif npe_name == 'scannet':
            # NOTE(review): SCAN_IMG_RESIZEX fills both trailing entries and
            # SCAN_IMG_RESIZEY is not read here; kept as in the original,
            # confirm this is intended.
            config.LOFTR.COARSE.NPE = [832, 832, config.DATASET.SCAN_IMG_RESIZEX, config.DATASET.SCAN_IMG_RESIZEX] # [832, 832, 640, 640]
    else:
        config.LOFTR.COARSE.NPE = [832, 832, 832, 832]

    if args.ransac_times is not None:
        config.LOFTR.EVAL_TIMES = args.ransac_times

    if args.rmbd is not None:
        config.LOFTR.MATCH_COARSE.BORDER_RM = args.rmbd

    # --pixel_thr is applied before the --ransac handling below, which reads
    # the resulting threshold.
    if args.pixel_thr is not None:
        config.TRAINER.RANSAC_PIXEL_THR = args.pixel_thr

    if args.ransac is not None:
        config.TRAINER.POSE_ESTIMATION_METHOD = args.ransac
        # For LO-RANSAC, a threshold that is still 0.5 is raised to 2.0.
        if args.ransac == 'LO-RANSAC' and config.TRAINER.RANSAC_PIXEL_THR == 0.5:
            config.TRAINER.RANSAC_PIXEL_THR = 2.0

    if args.fp32:
        config.LOFTR.MP = False

    # --half is a store_true flag, so both settings always follow its boolean
    # value, overwriting whatever the config files contained.
    config.LOFTR.HALF = args.half
    config.DATASET.FP16 = args.half

    if args.flash:
        config.LOFTR.COARSE.NO_FLASH = False

    loguru_logger.info("Args and config initialized!")

    # lightning module
    profiler = build_profiler(args.profiler_name)
    model = PL_LoFTR(config, pretrained_ckpt=args.ckpt_path, profiler=profiler, dump_dir=args.dump_dir)
    loguru_logger.info("LoFTR-lightning initialized!")

    # lightning data
    data_module = MultiSceneDataModule(args, config)
    loguru_logger.info("DataModule initialized!")

    # lightning trainer
    trainer = pl.Trainer.from_argparse_args(args, replace_sampler_ddp=False, logger=False)

    loguru_logger.info("Start testing!")
    trainer.test(model, datamodule=data_module, verbose=False)