# Source snapshot: commit bd86ed9 (7,972 bytes).
import torch
import torch.backends.cudnn as cudnn
import os, sys
import argparse
import numpy as np
from tqdm import tqdm
from utils import post_process_depth, flip_lr, compute_errors
from networks.NewCRFDepth import NewCRFDepth
def convert_arg_line_to_args(arg_line):
    """Split one line of an ``@``-prefixed argument file into tokens.

    argparse calls this hook for every line it reads from a
    ``fromfile_prefix_chars`` file. Splitting on arbitrary whitespace lets a
    line hold several tokens, e.g. ``--encoder large07``.

    Yields:
        Each non-blank whitespace-separated token of ``arg_line``, in order.
    """
    tokens = arg_line.split()
    yield from (token for token in tokens if token.strip())
# Command-line interface. Arguments may also be read from a file: either pass
# "@<file>" explicitly, or pass the file path as the only argument.
parser = argparse.ArgumentParser(description='IEBins PyTorch implementation.', fromfile_prefix_chars='@')
# Allow several whitespace-separated tokens per line in argument files.
parser.convert_arg_line_to_args = convert_arg_line_to_args

parser.add_argument('--model_name', type=str, help='model name', default='iebins')
parser.add_argument('--encoder', type=str, help='type of encoder, base07, large07, tiny07', default='large07')
parser.add_argument('--checkpoint_path', type=str, help='path to a checkpoint to load', default='')

# Dataset
parser.add_argument('--dataset', type=str, help='dataset to train on, kitti or nyu', default='nyu')
parser.add_argument('--input_height', type=int, help='input height', default=480)
parser.add_argument('--input_width', type=int, help='input width', default=640)
parser.add_argument('--max_depth', type=float, help='maximum depth in estimation', default=10)

# Preprocessing
parser.add_argument('--do_random_rotate', help='if set, will perform random rotation for augmentation', action='store_true')
parser.add_argument('--degree', type=float, help='random rotation maximum degree', default=2.5)
parser.add_argument('--do_kb_crop', help='if set, crop input images as kitti benchmark images', action='store_true')
parser.add_argument('--use_right', help='if set, will randomly use right images when train on KITTI', action='store_true')

# Eval
parser.add_argument('--data_path_eval', type=str, help='path to the data for evaluation', required=False)
parser.add_argument('--gt_path_eval', type=str, help='path to the groundtruth data for evaluation', required=False)
parser.add_argument('--filenames_file_eval', type=str, help='path to the filenames text file for evaluation', required=False)
parser.add_argument('--min_depth_eval', type=float, help='minimum depth for evaluation', default=1e-3)
parser.add_argument('--max_depth_eval', type=float, help='maximum depth for evaluation', default=80)
parser.add_argument('--eigen_crop', help='if set, crops according to Eigen NIPS14', action='store_true')
parser.add_argument('--garg_crop', help='if set, crops according to Garg ECCV16', action='store_true')

# A single bare argument is treated as an argument file (same as "@<file>").
# Options ("-..."), and arguments that already carry the '@' prefix, are parsed
# normally. Without this check, `python eval.py --help` would try to open a
# file literally named "--help", and "@args.txt" would become "@@args.txt".
if len(sys.argv) == 2 and not sys.argv[1].startswith(('-', '@')):
    arg_filename_with_prefix = '@' + sys.argv[1]
    args = parser.parse_args([arg_filename_with_prefix])
else:
    args = parser.parse_args()

if args.dataset == 'kitti' or args.dataset == 'nyu':
    from dataloaders.dataloader import NewDataLoader
else:
    # main_worker needs NewDataLoader. Fail now with a usage message instead
    # of a NameError after the model has been built.
    parser.error("unsupported --dataset '{}': expected 'kitti' or 'nyu'".format(args.dataset))
def eval(model, dataloader_eval, post_process=False):
    """Evaluate ``model`` on ``dataloader_eval`` and print averaged depth metrics.

    Reads the module-level ``args`` for the evaluation depth range
    (``min_depth_eval`` / ``max_depth_eval``) and the cropping options
    (``do_kb_crop``, ``garg_crop``, ``eigen_crop``, ``dataset``). Inputs and the
    metric accumulator are moved to CUDA, so a GPU is required.

    Args:
        model: network called as ``model(image)``. Its first return value is a
            list of depth predictions, of which the last one is evaluated.
        dataloader_eval: object whose ``.data`` iterates over sample dicts with
            ``'image'``, ``'depth'`` and ``'has_valid_depth'`` entries.
        post_process: if True, also run the horizontally flipped image and
            merge both predictions with ``post_process_depth``. Otherwise the
            plain prediction is evaluated.

    Returns:
        CPU tensor of length 10. Entries 0-8 are per-sample averages of the
        values returned by ``compute_errors``. They are assumed to follow the
        printed header order: silog, abs_rel, log10, rms, sq_rel, log_rms,
        d1, d2, d3. Entry 9 is the sample count divided by itself (1.0). If no
        sample was evaluated, the all-zero tensor is returned unscaled.
    """
    # Size of the KITTI benchmark crop that --do_kb_crop pastes back into place.
    kb_crop_height, kb_crop_width = 352, 1216

    eval_measures = torch.zeros(10).cuda()
    for eval_sample_batched in tqdm(dataloader_eval.data):
        with torch.no_grad():
            image = eval_sample_batched['image'].cuda()
            gt_depth = eval_sample_batched['depth']
            # NOTE(review): truth-testing this value assumes batch size 1.
            has_valid_depth = eval_sample_batched['has_valid_depth']
            if not has_valid_depth:
                continue

            pred_depths_r_list, _, _ = model(image)
            if post_process:
                # Test-time flip augmentation: merge the predictions for the
                # image and its horizontal mirror.
                image_flipped = flip_lr(image)
                pred_depths_r_list_flipped, _, _ = model(image_flipped)
                pred_depth = post_process_depth(pred_depths_r_list[-1], pred_depths_r_list_flipped[-1])
            else:
                # Evaluate the final-stage prediction directly.
                pred_depth = pred_depths_r_list[-1]

            pred_depth = pred_depth.cpu().numpy().squeeze()
            gt_depth = gt_depth.cpu().numpy().squeeze()

        if args.do_kb_crop:
            # The prediction covers only the bottom-centred benchmark crop.
            # Paste it into a zero canvas the size of the ground truth.
            height, width = gt_depth.shape
            top_margin = int(height - kb_crop_height)
            left_margin = int((width - kb_crop_width) / 2)
            pred_depth_uncropped = np.zeros((height, width), dtype=np.float32)
            pred_depth_uncropped[top_margin:top_margin + kb_crop_height,
                                 left_margin:left_margin + kb_crop_width] = pred_depth
            pred_depth = pred_depth_uncropped

        # Clamp predictions into the evaluation range and replace non-finite values.
        pred_depth[pred_depth < args.min_depth_eval] = args.min_depth_eval
        pred_depth[pred_depth > args.max_depth_eval] = args.max_depth_eval
        pred_depth[np.isinf(pred_depth)] = args.max_depth_eval
        pred_depth[np.isnan(pred_depth)] = args.min_depth_eval

        # Only score pixels whose ground truth lies strictly inside the range.
        valid_mask = np.logical_and(gt_depth > args.min_depth_eval, gt_depth < args.max_depth_eval)

        if args.garg_crop or args.eigen_crop:
            gt_height, gt_width = gt_depth.shape
            eval_mask = np.zeros(valid_mask.shape)

            if args.garg_crop:
                # Garg ECCV16 crop, expressed as fractions of the image size.
                eval_mask[int(0.40810811 * gt_height):int(0.99189189 * gt_height), int(0.03594771 * gt_width):int(0.96405229 * gt_width)] = 1

            elif args.eigen_crop:
                if args.dataset == 'kitti':
                    # Eigen NIPS14 crop, expressed as fractions of the image size.
                    eval_mask[int(0.3324324 * gt_height):int(0.91351351 * gt_height), int(0.0359477 * gt_width):int(0.96405229 * gt_width)] = 1
                elif args.dataset == 'nyu':
                    # Fixed pixel window used for NYU.
                    eval_mask[45:471, 41:601] = 1

            valid_mask = np.logical_and(valid_mask, eval_mask)

        measures = compute_errors(gt_depth[valid_mask], pred_depth[valid_mask])

        eval_measures[:9] += torch.tensor(measures).cuda()
        eval_measures[9] += 1  # sample counter

    eval_measures_cpu = eval_measures.cpu()
    cnt = eval_measures_cpu[9].item()
    if cnt == 0:
        # Avoid 0/0 (NaN metrics) when every sample was skipped.
        print('No valid eval samples, post_process: ', post_process)
        return eval_measures_cpu
    eval_measures_cpu /= cnt

    print('Computing errors for {} eval samples'.format(int(cnt)), ', post_process: ', post_process)
    print("{:>7}, {:>7}, {:>7}, {:>7}, {:>7}, {:>7}, {:>7}, {:>7}, {:>7}".format('silog', 'abs_rel', 'log10', 'rms',
                                                                                 'sq_rel', 'log_rms', 'd1', 'd2',
                                                                                 'd3'))
    for i in range(8):
        print('{:7.4f}, '.format(eval_measures_cpu[i]), end='')
    print('{:7.4f}'.format(eval_measures_cpu[8]))
    return eval_measures_cpu
def main_worker(args):
    """Build the NewCRFDepth model, load its checkpoint and evaluate it.

    Args:
        args: parsed command-line namespace. ``encoder`` and ``max_depth``
            configure the model, ``checkpoint_path`` selects the weights, and
            the whole namespace is handed to ``NewDataLoader``.

    Returns:
        The metrics tensor returned by ``eval``, run with flip
        post-processing enabled.

    Raises:
        FileNotFoundError: if ``args.checkpoint_path`` is set but is not an
            existing file. Evaluating untrained weights would report
            meaningless numbers.
    """
    # CRF model
    model = NewCRFDepth(version=args.encoder, inv_depth=False, max_depth=args.max_depth, pretrained=None)
    model.train()

    # numel() is exact for every parameter, including 0-dim ones.
    num_params = sum(p.numel() for p in model.parameters())
    print("== Total number of parameters: {}".format(num_params))

    num_params_update = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print("== Total number of learning parameters: {}".format(num_params_update))

    # Wrap before loading: the checkpoint keys are expected to match the
    # DataParallel-wrapped model (TODO confirm against the training script).
    model = torch.nn.DataParallel(model)
    model.cuda()

    print("== Model Initialized")

    if args.checkpoint_path != '':
        if not os.path.isfile(args.checkpoint_path):
            raise FileNotFoundError("== No checkpoint found at '{}'".format(args.checkpoint_path))
        print("== Loading checkpoint '{}'".format(args.checkpoint_path))
        # NOTE(review): torch.load unpickles the file - only load trusted checkpoints.
        checkpoint = torch.load(args.checkpoint_path, map_location='cpu')
        model.load_state_dict(checkpoint['model'])
        print("== Loaded checkpoint '{}'".format(args.checkpoint_path))
        del checkpoint  # release the CPU copy of the weights

    cudnn.benchmark = True

    dataloader_eval = NewDataLoader(args, 'online_eval')

    # ===== Evaluation ======
    model.eval()
    with torch.no_grad():
        eval_measures = eval(model, dataloader_eval, post_process=True)
    return eval_measures
def main():
    """Check the GPU setup and run single-GPU evaluation with the parsed ``args``.

    Returns:
        -1 if the number of visible CUDA devices is not exactly one, in which
        case nothing is evaluated. Otherwise returns None.
    """
    torch.cuda.empty_cache()
    args.distributed = False
    ngpus_per_node = torch.cuda.device_count()
    if ngpus_per_node > 1:
        print("This machine has more than 1 gpu. Please set 'CUDA_VISIBLE_DEVICES=0'")
        return -1
    if ngpus_per_node == 0:
        # The model and metrics are moved to CUDA, so fail clearly up front.
        print("No CUDA GPU is visible; this evaluation script requires one.")
        return -1

    main_worker(args)
if __name__ == '__main__':
    # Propagate main()'s status as the process exit code: None exits 0,
    # while -1 (refused GPU setup) exits non-zero.
    sys.exit(main())
|