umuthopeyildirim's picture
here we go
bd86ed9
raw
history blame
7.97 kB
import torch
import torch.backends.cudnn as cudnn
import os, sys
import argparse
import numpy as np
from tqdm import tqdm
from utils import post_process_depth, flip_lr, compute_errors
from networks.NewCRFDepth import NewCRFDepth
def convert_arg_line_to_args(arg_line):
    """Split one line of an ``@``-prefixed argument file into tokens.

    Installed as ``parser.convert_arg_line_to_args`` so that a single line of
    the file may carry several whitespace-separated arguments
    (e.g. ``--dataset nyu``). Tokens are yielded lazily, in order; blank
    tokens are dropped.
    """
    tokens = arg_line.split()
    yield from (token for token in tokens if token.strip())
parser = argparse.ArgumentParser(description='IEBins PyTorch implementation.', fromfile_prefix_chars='@')
# Allow several whitespace-separated arguments per line in @-files.
parser.convert_arg_line_to_args = convert_arg_line_to_args

parser.add_argument('--model_name', type=str, help='model name', default='iebins')
parser.add_argument('--encoder', type=str, help='type of encoder, base07, large07, tiny07', default='large07')
parser.add_argument('--checkpoint_path', type=str, help='path to a checkpoint to load', default='')

# Dataset
# `choices` makes argparse reject unsupported datasets up front; otherwise the
# dataloader import below is skipped and the script fails later with NameError.
parser.add_argument('--dataset', type=str, help='dataset to train on, kitti or nyu', default='nyu',
                    choices=['kitti', 'nyu'])
parser.add_argument('--input_height', type=int, help='input height', default=480)
parser.add_argument('--input_width', type=int, help='input width', default=640)
parser.add_argument('--max_depth', type=float, help='maximum depth in estimation', default=10)

# Preprocessing
parser.add_argument('--do_random_rotate', help='if set, will perform random rotation for augmentation', action='store_true')
parser.add_argument('--degree', type=float, help='random rotation maximum degree', default=2.5)
parser.add_argument('--do_kb_crop', help='if set, crop input images as kitti benchmark images', action='store_true')
parser.add_argument('--use_right', help='if set, will randomly use right images when train on KITTI', action='store_true')

# Eval
parser.add_argument('--data_path_eval', type=str, help='path to the data for evaluation', required=False)
parser.add_argument('--gt_path_eval', type=str, help='path to the groundtruth data for evaluation', required=False)
parser.add_argument('--filenames_file_eval', type=str, help='path to the filenames text file for evaluation', required=False)
parser.add_argument('--min_depth_eval', type=float, help='minimum depth for evaluation', default=1e-3)
parser.add_argument('--max_depth_eval', type=float, help='maximum depth for evaluation', default=80)
parser.add_argument('--eigen_crop', help='if set, crops according to Eigen NIPS14', action='store_true')
parser.add_argument('--garg_crop', help='if set, crops according to Garg ECCV16', action='store_true')

# A single bare argument (`python <script> args.txt`) is treated as an
# arguments file, as if it had been written `@args.txt`. A lone option such as
# `--help`, or an explicit `@file`, is left for argparse to handle normally
# instead of being mistaken for a file name.
if len(sys.argv) == 2 and not sys.argv[1].startswith(('-', '@')):
    args = parser.parse_args(['@' + sys.argv[1]])
else:
    args = parser.parse_args()

# Both supported datasets share the same loader.
if args.dataset in ('kitti', 'nyu'):
    from dataloaders.dataloader import NewDataLoader
def _kb_uncrop(pred_depth, height, width):
    """Paste a KITTI-benchmark-cropped prediction into a zero canvas of (height, width).

    The crop occupies the bottom 352 rows and the horizontally centred 1216
    columns; every pixel outside it stays 0.
    """
    top_margin = int(height - 352)
    left_margin = int((width - 1216) / 2)
    uncropped = np.zeros((height, width), dtype=np.float32)
    uncropped[top_margin:top_margin + 352, left_margin:left_margin + 1216] = pred_depth
    return uncropped


def _eval_mask(gt_depth):
    """Return the boolean mask of pixels that contribute to the metrics.

    Starts from ground-truth depths strictly inside
    (args.min_depth_eval, args.max_depth_eval). When --garg_crop or
    --eigen_crop is set, the mask is intersected with that crop region.
    """
    valid_mask = np.logical_and(gt_depth > args.min_depth_eval, gt_depth < args.max_depth_eval)
    if args.garg_crop or args.eigen_crop:
        gt_height, gt_width = gt_depth.shape
        crop_mask = np.zeros(valid_mask.shape, dtype=bool)
        if args.garg_crop:
            crop_mask[int(0.40810811 * gt_height):int(0.99189189 * gt_height),
                      int(0.03594771 * gt_width):int(0.96405229 * gt_width)] = True
        elif args.eigen_crop:
            if args.dataset == 'kitti':
                crop_mask[int(0.3324324 * gt_height):int(0.91351351 * gt_height),
                          int(0.0359477 * gt_width):int(0.96405229 * gt_width)] = True
            elif args.dataset == 'nyu':
                # Fixed pixel window; presumably assumes 480x640 NYU ground truth.
                crop_mask[45:471, 41:601] = True
        valid_mask = np.logical_and(valid_mask, crop_mask)
    return valid_mask


def eval(model, dataloader_eval, post_process=False):
    """Evaluate ``model`` on ``dataloader_eval`` and print averaged depth metrics.

    Args:
        model: depth network called as ``model(image)``. It returns a 3-tuple
            whose first element is a sequence of depth predictions; only the
            last one is scored.
        dataloader_eval: object whose ``.data`` attribute is iterated for
            batches. Each batch is a dict with 'image', 'depth' and
            'has_valid_depth' entries.
        post_process: if True, the model is also run on the horizontally
            flipped image and the two predictions are fused with
            ``post_process_depth``.

    Returns:
        A CPU tensor of 10 values. Entries 0-8 are the ``compute_errors``
        metrics averaged over the evaluated samples (presumably in the order of
        the printed header). Entry 9 is the sample count divided by itself, so
        it is 1.0. If no sample had valid depth, the un-normalised zeros are
        returned instead of NaNs.

    Reads the module-level ``args`` for depth limits and crop options.
    """
    eval_measures = torch.zeros(10).cuda()
    with torch.no_grad():
        for eval_sample_batched in tqdm(dataloader_eval.data):
            # Samples without usable ground truth are skipped entirely.
            if not eval_sample_batched['has_valid_depth']:
                continue
            image = eval_sample_batched['image'].cuda()
            gt_depth = eval_sample_batched['depth']

            pred_depths_r_list, _, _ = model(image)
            if post_process:
                image_flipped = flip_lr(image)
                pred_depths_r_list_flipped, _, _ = model(image_flipped)
                pred_depth = post_process_depth(pred_depths_r_list[-1], pred_depths_r_list_flipped[-1])
            else:
                # Without flip fusion, score the final prediction directly.
                pred_depth = pred_depths_r_list[-1]

            pred_depth = pred_depth.cpu().numpy().squeeze()
            gt_depth = gt_depth.cpu().numpy().squeeze()

            if args.do_kb_crop:
                height, width = gt_depth.shape
                pred_depth = _kb_uncrop(pred_depth, height, width)

            # Clamp predictions into the evaluation range. NaNs fail both
            # comparisons, so they are mapped to the minimum explicitly.
            pred_depth[pred_depth < args.min_depth_eval] = args.min_depth_eval
            pred_depth[pred_depth > args.max_depth_eval] = args.max_depth_eval
            pred_depth[np.isinf(pred_depth)] = args.max_depth_eval
            pred_depth[np.isnan(pred_depth)] = args.min_depth_eval

            valid_mask = _eval_mask(gt_depth)
            measures = compute_errors(gt_depth[valid_mask], pred_depth[valid_mask])
            eval_measures[:9] += torch.tensor(measures).cuda()
            eval_measures[9] += 1  # running count of evaluated samples

    eval_measures_cpu = eval_measures.cpu()
    cnt = eval_measures_cpu[9].item()
    if cnt == 0:
        # Avoid 0/0 -> NaN metrics when every sample was skipped.
        print('No valid eval samples found, post_process: ', post_process)
        return eval_measures_cpu
    eval_measures_cpu /= cnt
    print('Computing errors for {} eval samples'.format(int(cnt)), ', post_process: ', post_process)
    print("{:>7}, {:>7}, {:>7}, {:>7}, {:>7}, {:>7}, {:>7}, {:>7}, {:>7}".format('silog', 'abs_rel', 'log10', 'rms',
                                                                                 'sq_rel', 'log_rms', 'd1', 'd2',
                                                                                 'd3'))
    for i in range(8):
        print('{:7.4f}, '.format(eval_measures_cpu[i]), end='')
    print('{:7.4f}'.format(eval_measures_cpu[8]))
    return eval_measures_cpu
def main_worker(args):
    """Build the NewCRFDepth model, load ``args.checkpoint_path`` and evaluate it.

    Args:
        args: parsed command-line namespace (encoder, max_depth,
            checkpoint_path and the dataloader options).

    Returns:
        The metrics tensor returned by ``eval`` (post-processing enabled).

    Raises:
        FileNotFoundError: if ``args.checkpoint_path`` is non-empty but does not
            name a file. Evaluating an unloaded model would silently report
            meaningless metrics.
    """
    # CRF model
    model = NewCRFDepth(version=args.encoder, inv_depth=False, max_depth=args.max_depth, pretrained=None)
    # Kept before counting parameters: a train() override in the model may
    # affect requires_grad, and hence the learnable-parameter count below.
    model.train()

    num_params = sum(p.numel() for p in model.parameters())
    print("== Total number of parameters: {}".format(num_params))

    num_params_update = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print("== Total number of learning parameters: {}".format(num_params_update))

    model = torch.nn.DataParallel(model)
    model.cuda()
    print("== Model Initialized")

    if args.checkpoint_path != '':
        if not os.path.isfile(args.checkpoint_path):
            raise FileNotFoundError("== No checkpoint found at '{}'".format(args.checkpoint_path))
        print("== Loading checkpoint '{}'".format(args.checkpoint_path))
        # torch.load unpickles the file: only load checkpoints from trusted sources.
        checkpoint = torch.load(args.checkpoint_path, map_location='cpu')
        model.load_state_dict(checkpoint['model'])
        print("== Loaded checkpoint '{}'".format(args.checkpoint_path))
        del checkpoint
    else:
        print("== No checkpoint given; evaluating the model without loading weights")

    cudnn.benchmark = True

    dataloader_eval = NewDataLoader(args, 'online_eval')

    # ===== Evaluation ======
    model.eval()
    with torch.no_grad():
        return eval(model, dataloader_eval, post_process=True)
def main():
    """Entry point: validate the GPU setup, then run evaluation.

    Returns -1, after printing the reason, when the machine does not expose
    exactly one CUDA device. Otherwise returns None once evaluation finishes.
    """
    torch.cuda.empty_cache()
    args.distributed = False
    ngpus_per_node = torch.cuda.device_count()
    if ngpus_per_node == 0:
        # Without this check the script fails later inside model.cuda().
        print("No CUDA device found; this evaluation script requires one GPU.")
        return -1
    if ngpus_per_node > 1:
        print("This machine has more than 1 gpu. Please set \'CUDA_VISIBLE_DEVICES=0\'")
        return -1

    main_worker(args)
if __name__ == '__main__':
    # Propagate main()'s -1 failure code as a non-zero process exit status;
    # None (success) exits with status 0.
    sys.exit(main())