"""Export a MiDaS depth-estimation checkpoint to ONNX."""
|
import os
import ntpath
import sys

import numpy as np
import torch
from shutil import copyfile

# Make the sibling `midas` package importable when running from this folder.
sys.path.append(os.getcwd() + '/..')
|
|
def modify_file():
    """Temporarily patch ../midas/blocks.py so the network can be exported.

    Two changes are needed before the midas modules are imported:
      * align_corners=True cannot be exported at ONNX opset 9, so the
        interpolation calls are switched to align_corners=False;
      * the torch.hub ResNeXt backbone is swapped for the torchvision
        equivalent so no download is required (the real weights come from
        the checkpoint loaded by MidasNet anyway).
    """
    modify_filename = '../midas/blocks.py'
    copyfile(modify_filename, modify_filename + '.bak')

    with open(modify_filename, 'r') as file:
        filedata = file.read()

    filedata = filedata.replace('align_corners=True', 'align_corners=False')
    filedata = filedata.replace('import torch.nn as nn', 'import torch.nn as nn\nimport torchvision.models as models')
    filedata = filedata.replace('torch.hub.load("facebookresearch/WSL-Images", "resnext101_32x8d_wsl")', 'models.resnext101_32x8d()')

    with open(modify_filename, 'w') as file:
        file.write(filedata)


def restore_file():
    """Restore the original ../midas/blocks.py from the backup."""
    modify_filename = '../midas/blocks.py'
    copyfile(modify_filename + '.bak', modify_filename)


# Patch blocks.py, import the model while the patch is active, then restore
# the file on disk.
modify_file()

from midas.midas_net import MidasNet

restore_file()
|
|
class MidasNet_preprocessing(MidasNet):
    """MidasNet with the input normalization baked into the graph.

    Folding the mean/std normalization into forward() means the exported
    ONNX model accepts raw RGB values in [0, 1] directly.
    """

    def forward(self, x):
        """Forward pass.

        Args:
            x (tensor): input data (image)

        Returns:
            tensor: depth
        """
        # ImageNet mean/std normalization, applied in-place on the input.
        mean = torch.tensor([0.485, 0.456, 0.406])
        std = torch.tensor([0.229, 0.224, 0.225])
        x.sub_(mean[None, :, None, None]).div_(std[None, :, None, None])

        return MidasNet.forward(self, x)
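

# A minimal sketch of how an image could be fed to MidasNet_preprocessing
# (hypothetical helper, not part of the original script; assumes OpenCV and
# the fixed 384x384 network input used below):
def load_image_as_tensor(path):
    """Read an image and return a 1x3x384x384 float tensor with RGB in [0, 1]."""
    import cv2  # local import: only needed for this helper

    img = cv2.imread(path)                      # BGR, uint8, HxWxC
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # the network expects RGB
    img = cv2.resize(img, (384, 384)).astype(np.float32) / 255.0
    img = np.ascontiguousarray(img.transpose(2, 0, 1))  # HWC -> CHW
    return torch.from_numpy(img).unsqueeze(0)   # add batch dimension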
|
|
def run(model_path):
    """Export MonoDepthNN to ONNX.

    Args:
        model_path (str): path to saved model
    """
    print("initialize")

    # Load network (the preprocessing variant defined above).
    model = MidasNet_preprocessing(model_path, non_negative=True)
    model.eval()

    print("start processing")

    # Dummy input: a single 3x384x384 image, used both for a sanity-check
    # forward pass and as the example input for tracing the ONNX graph.
    img_input = np.zeros((3, 384, 384), np.float32)

    with torch.no_grad():
        sample = torch.from_numpy(img_input).unsqueeze(0)
        # Sanity-check forward pass; the result itself is not used further.
        prediction = model.forward(sample)
        prediction = (
            torch.nn.functional.interpolate(
                prediction.unsqueeze(1),
                size=img_input.shape[1:],  # (height, width); shape[:2] would wrongly be (channels, height)
                mode="bicubic",
                align_corners=False,
            )
            .squeeze()
            .cpu()
            .numpy()
        )

    # Trace the model and write <checkpoint name>.onnx next to this script.
    torch.onnx.export(model, sample, ntpath.basename(model_path).rsplit('.', 1)[0] + '.onnx', opset_version=9)

    print("finished")
|
|
if __name__ == "__main__":
    MODEL_PATH = "../model-f6b98070.pt"

    run(MODEL_PATH)
|
|