File size: 4,124 Bytes
bd86ed9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 |
import torch
import torch.backends.cudnn as cudnn
import os, sys
import argparse
import numpy as np
from tqdm import tqdm
from utils import post_process_depth, flip_lr, compute_errors
from networks.NewCRFDepth import NewCRFDepth
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt
def convert_arg_line_to_args(arg_line):
    """Split one line of an ``@``-prefixed arguments file into argv tokens.

    Installed as ``parser.convert_arg_line_to_args`` so a single line of the
    arguments file may hold several whitespace-separated options. Yields the
    tokens in their original order, skipping any that are blank.
    """
    tokens = arg_line.split()
    yield from filter(str.strip, tokens)
parser = argparse.ArgumentParser(description='IEBins PyTorch implementation.', fromfile_prefix_chars='@')
# Allow several whitespace-separated options per line in an @args file.
parser.convert_arg_line_to_args = convert_arg_line_to_args
parser.add_argument('--model_name', type=str, help='model name', default='iebins')
parser.add_argument('--encoder', type=str, help='type of encoder, base07, large07', default='large07')
parser.add_argument('--checkpoint_path', type=str, help='path to a checkpoint to load', default='')
parser.add_argument('--dataset', type=str, help='dataset to train on, kitti or nyu', default='nyu')
# inference() opens this path unconditionally, so fail at parse time rather
# than after the model has been built and the checkpoint loaded.
parser.add_argument('--image_path', type=str, help='path to the image for inference', required=True)
parser.add_argument('--max_depth', type=float, help='maximum depth in estimation', default=10)

# A single bare CLI argument is shorthand for an arguments file
# (``script.py args.txt`` == ``script.py @args.txt``). Options such as
# ``--help`` or ``--image_path=x.png``, and an already-prefixed ``@file``,
# are passed through unchanged instead of being opened as a file.
_cli_argv = sys.argv[1:]
if len(_cli_argv) == 1 and not _cli_argv[0].startswith(('-', '@')):
    _cli_argv = ['@' + _cli_argv[0]]
args = parser.parse_args(_cli_argv)
def _load_image(image_path, dataset):
    """Read an image file and return a normalized (1, 3, H, W) float tensor."""
    with Image.open(image_path) as img:
        # Force 3 channels: grayscale / RGBA / palette images would otherwise
        # break the HWC->CHW transpose or the 3-channel Normalize below.
        image = np.asarray(img.convert('RGB'), dtype=np.float32) / 255.0
    if dataset == 'kitti':
        # Keep the bottom 352 rows and a horizontally centred 1216-column window.
        # NOTE(review): inputs smaller than 352x1216 produce negative margins
        # and therefore a wrong crop; assumes full-size KITTI frames — confirm.
        height, width = image.shape[:2]
        top_margin = int(height - 352)
        left_margin = int((width - 1216) / 2)
        image = image[top_margin:top_margin + 352, left_margin:left_margin + 1216, :]
    image = torch.from_numpy(image.transpose((2, 0, 1)))
    image = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])(image)
    return image.unsqueeze(0)


def _predict_depth(model, image, post_process):
    """Run the model and return the final depth prediction as a squeezed numpy array."""
    # The model returns a 3-tuple; the last entry of its first element is used
    # as the final prediction.
    pred_depths_r_list, _, _ = model(image)
    pred_depth = pred_depths_r_list[-1]
    if post_process:
        # Combine with the prediction for a flipped copy of the input
        # (flip_lr / post_process_depth come from utils; presumably
        # left-right flip averaging — see their implementation).
        image_flipped = flip_lr(image)
        pred_depths_r_list_flipped, _, _ = model(image_flipped)
        pred_depth = post_process_depth(pred_depth, pred_depths_r_list_flipped[-1])
    return pred_depth.cpu().numpy().squeeze()


def _save_depth_map(pred_depth, dataset, output_path):
    """Write a colour-mapped depth image (log10 + magma for KITTI, jet otherwise)."""
    if dataset == 'kitti':
        # NOTE(review): log10 of zero/negative depths gives -inf/NaN; assumes the
        # model outputs strictly positive depth — confirm.
        plt.imsave(output_path, np.log10(pred_depth), cmap='magma')
    else:
        plt.imsave(output_path, pred_depth, cmap='jet')


def inference(model, post_process=False, output_path='depth.png'):
    """Predict depth for ``args.image_path`` and save it as a colour-mapped image.

    Reads the module-level ``args`` (``image_path`` and ``dataset``).

    Args:
        model: depth network; called as ``model(image)`` and expected to return
            a 3-tuple whose first element is a list of predictions.
        post_process: if True, also predict on a flipped copy of the input and
            merge both predictions with ``post_process_depth``.
        output_path: file the colour-mapped depth image is written to.
    """
    image = _load_image(args.image_path, args.dataset)
    # Place the input on the same device as the model's weights (CUDA when
    # main_worker has moved the model there).
    device = next(model.parameters()).device
    with torch.no_grad():
        pred_depth = _predict_depth(model, image.to(device), post_process)
    _save_depth_map(pred_depth, args.dataset, output_path)
def main_worker(args):
    """Build NewCRFDepth, optionally restore a checkpoint, and run inference with it.

    Args:
        args: parsed CLI namespace; this function reads ``encoder``,
            ``max_depth`` and ``checkpoint_path``. ``inference`` reads the
            module-level ``args`` (image path, dataset), not this parameter.
    """
    model = NewCRFDepth(version=args.encoder, inv_depth=False, max_depth=args.max_depth, pretrained=None)
    model.train()
    # p.numel() is an exact Python int; np.prod(p.size()) yields the float 1.0
    # for 0-dim parameters, which would turn the whole total into a float.
    num_params = sum(p.numel() for p in model.parameters())
    print("== Total number of parameters: {}".format(num_params))
    num_params_update = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print("== Total number of learning parameters: {}".format(num_params_update))

    # Wrapped before loading so state-dict keys match a checkpoint saved from a
    # DataParallel model (presumably "module."-prefixed keys — confirm).
    model = torch.nn.DataParallel(model)
    model.cuda()
    print("== Model Initialized")

    if args.checkpoint_path != '':
        if os.path.isfile(args.checkpoint_path):
            checkpoint = torch.load(args.checkpoint_path, map_location='cpu')
            model.load_state_dict(checkpoint['model'])
            print("== Loaded checkpoint '{}'".format(args.checkpoint_path))
            del checkpoint  # drop the CPU copy of the weights
        else:
            # NOTE(review): execution continues with untrained weights here.
            print("== No checkpoint found at '{}'".format(args.checkpoint_path))

    cudnn.benchmark = True

    # ===== Inference ======
    model.eval()
    with torch.no_grad():
        inference(model, post_process=True)
def main():
    """Script entry point: run the worker on a single-GPU (or CPU-only) machine.

    Returns:
        -1 if more than one CUDA device is visible (after printing a hint to
        restrict visibility), otherwise None once inference has finished.
    """
    torch.cuda.empty_cache()
    args.distributed = False
    if torch.cuda.device_count() <= 1:
        main_worker(args)
        return None
    print("This machine has more than 1 gpu. Please set \'CUDA_VISIBLE_DEVICES=0\'")
    return -1
if __name__ == '__main__':
    # Propagate main()'s status so the multi-GPU refusal (-1) gives a non-zero
    # exit code; a normal run returns None, which exits with 0.
    sys.exit(main())
|