"""Sample script to run DepthPro. |
|
|
|
Copyright (C) 2024 Apple Inc. All Rights Reserved. |
|
""" |
|
import argparse
import logging
from pathlib import Path

import numpy as np
import PIL.Image
import torch
from matplotlib import pyplot as plt
from tqdm import tqdm

from depth_pro import create_model_and_transforms, load_rgb

LOGGER = logging.getLogger(__name__)
|
|
def get_torch_device() -> torch.device:
    """Get the Torch device, preferring CUDA, then MPS, then CPU."""
    device = torch.device("cpu")
    if torch.cuda.is_available():
        device = torch.device("cuda:0")
    elif torch.backends.mps.is_available():
        device = torch.device("mps")
    return device
|
|
def run(args):
    """Run Depth Pro on a sample image."""
    if args.verbose:
        logging.basicConfig(level=logging.INFO)

    # Load the model and preprocessing transform.
    model, transform = create_model_and_transforms(
        device=get_torch_device(),
        precision=torch.half,
    )
    model.eval()
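    # Note: `precision=torch.half` runs inference in fp16, which roughly halves
    # memory use on CUDA/MPS. On CPU, fp16 support is limited, so torch.float
    # may be the safer choice there (an assumption, not part of the original script).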
|
    # Collect the input image path(s). Note that `glob("**/*")` also yields
    # directories and non-image files; those fail in `load_rgb` below and are
    # skipped by the try/except.
    image_paths = [args.image_path]
    if args.image_path.is_dir():
        image_paths = args.image_path.glob("**/*")
        relative_path = args.image_path
    else:
        relative_path = args.image_path.parent
|
    if not args.skip_display:
        plt.ion()
        fig = plt.figure()
        ax_rgb = fig.add_subplot(121)
        ax_disp = fig.add_subplot(122)
|
    for image_path in tqdm(image_paths):
        # Load the image and focal length from exif info (if found).
        try:
            LOGGER.info(f"Loading image {image_path} ...")
            image, _, f_px = load_rgb(image_path)
        except Exception as e:
            LOGGER.error(str(e))
            continue
|
        # Run prediction. If `f_px` is provided, it is used to compute the final
        # metric depth; otherwise the model estimates `f_px` itself.
        prediction = model.infer(transform(image), f_px=f_px)
|
depth = prediction["depth"].detach().cpu().numpy().squeeze() |
|
if f_px is not None: |
|
LOGGER.debug(f"Focal length (from exif): {f_px:0.2f}") |
|
elif prediction["focallength_px"] is not None: |
|
focallength_px = prediction["focallength_px"].detach().cpu().item() |
|
LOGGER.info(f"Estimated focal length: {focallength_px}") |
|
        inverse_depth = 1 / depth
        # Visualize inverse depth instead of depth, clipped to the [0.1m; 250m]
        # range for better visualization.
        max_invdepth_vizu = min(inverse_depth.max(), 1 / 0.1)
        min_invdepth_vizu = max(1 / 250, inverse_depth.min())
        inverse_depth_normalized = (inverse_depth - min_invdepth_vizu) / (
            max_invdepth_vizu - min_invdepth_vizu
        )
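        # Worked example (assuming the scene spans the full clip range, so the
        # bounds above are 1/0.1 and 1/250): a pixel at 5 m depth has inverse
        # depth 0.2 and normalizes to (0.2 - 1/250) / (10 - 1/250) ≈ 0.0196.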
|
        # Save depth as an npz file.
        if args.output_path is not None:
            output_file = (
                args.output_path
                / image_path.relative_to(relative_path).parent
                / image_path.stem
            )
            LOGGER.info(f"Saving depth map to: {str(output_file)}")
            output_file.parent.mkdir(parents=True, exist_ok=True)
            np.savez_compressed(output_file, depth=depth)

            # Save a color-mapped "turbo" jpg image. This stays inside the
            # `output_path` check because it reuses `output_file`.
            cmap = plt.get_cmap("turbo")
            color_depth = (cmap(inverse_depth_normalized)[..., :3] * 255).astype(
                np.uint8
            )
            color_map_output_file = str(output_file) + ".jpg"
            LOGGER.info(f"Saving color-mapped depth to: {color_map_output_file}")
            PIL.Image.fromarray(color_depth).save(
                color_map_output_file, format="JPEG", quality=90
            )
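            # Reading the saved depth back (a hedged downstream sketch;
            # `np.savez_compressed` appends the ".npz" suffix itself):
            #   data = np.load(str(output_file) + ".npz")
            #   depth = data["depth"]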
|
        # Display the image and estimated depth map.
        if not args.skip_display:
            ax_rgb.imshow(image)
            ax_disp.imshow(inverse_depth_normalized, cmap="turbo")
            fig.canvas.draw()
            fig.canvas.flush_events()
|
LOGGER.info("Done predicting depth!") |
|
if not args.skip_display: |
|
plt.show(block=True) |
|
|
def main():
    """Run DepthPro inference example."""
    parser = argparse.ArgumentParser(
        description="Inference script for DepthPro with PyTorch models."
    )
    parser.add_argument(
        "-i",
        "--image-path",
        type=Path,
        default="./data/example.jpg",
        help="Path to input image.",
    )
    parser.add_argument(
        "-o",
        "--output-path",
        type=Path,
        help="Path to store output files.",
    )
    parser.add_argument(
        "--skip-display",
        action="store_true",
        help="Skip matplotlib display.",
    )
    parser.add_argument(
        "-v",
        "--verbose",
        action="store_true",
        help="Show verbose output.",
    )

    run(parser.parse_args())
|
|
if __name__ == "__main__": |
|
main() |
|
|