|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import argparse
import copy
import functools
import math
import os
import os.path as path
import shutil
import sys
import tempfile
import tempfile

import gradio
import torch
import numpy as np
from tqdm import tqdm
import cv2
from PIL import Image

from dust3r.inference import inference
from dust3r.model import AsymmetricCroCo3DStereo
from dust3r.image_pairs import make_pairs
from dust3r.utils.image_pose import load_images, rgb, enlarge_seg_masks, resize_numpy_image
from dust3r.utils.device import to_numpy
from dust3r.cloud_opt_flow import global_aligner, GlobalAlignerMode
import matplotlib.pyplot as pl
from transformers import pipeline
from dust3r.utils.viz_demo import convert_scene_output_to_glb
import depth_pro
import spaces
from huggingface_hub import hf_hub_download
|
# ---------------------------------------------------------------------------
# Global runtime setup for the demo.
# ---------------------------------------------------------------------------

# Directory containing this script; also prepended to the import path.
HERE_PATH = path.normpath(path.dirname(__file__))
sys.path.insert(0, HERE_PATH)

# Interactive matplotlib mode (pyplot is only used here for colormaps).
pl.ion()

# Permit TensorFloat-32 for CUDA matmuls (faster, slightly lower precision).
torch.backends.cuda.matmul.allow_tf32 = True

# Demo-wide settings read by the callbacks below.
batch_size = 1
image_size = 512
silent = False
gradio_delete_cache = 7200  # forwarded to gradio.Blocks(delete_cache=...)

# Scratch directory that receives the exported GLB files.
tmpdirname = tempfile.mkdtemp(suffix='_align3r_gradio_demo')

# Fetch the Depth Pro checkpoint into the vendored third_party directory.
print(f'{HERE_PATH}/third_party/ml-depth-pro/checkpoints/')
hf_hub_download(
    repo_id="apple/DepthPro",
    filename='depth_pro.pt',
    local_dir=f'{HERE_PATH}/third_party/ml-depth-pro/checkpoints/',
)
|
|
|
class FileState:
    """Hold the path of an output file and delete that file on finalization."""

    def __init__(self, outfile_name=None):
        # Path to clean up later; None means there is nothing to delete.
        self.outfile_name = outfile_name

    def __del__(self):
        target = self.outfile_name
        # Only remove something that still exists as a regular file.
        if target is not None and os.path.isfile(target):
            os.remove(target)
        self.outfile_name = None
|
|
|
def get_3D_model_from_scene(outdir, silent, scene, min_conf_thr=3, as_pointcloud=False, mask_sky=False,
                            clean_depth=False, transparent_cams=False, cam_size=0.05, show_cam=True,
                            save_name=None, thr_for_init_conf=True):
    """Export a reconstructed scene as a 3D model (GLB) via convert_scene_output_to_glb.

    Args:
        outdir: output directory handed to convert_scene_output_to_glb.
        silent: forwarded to convert_scene_output_to_glb.
        scene: the optimized scene object, or None.
        min_conf_thr: stored on ``scene.min_conf_thr`` before the masks are queried.
        as_pointcloud, transparent_cams, cam_size, show_cam, save_name:
            forwarded unchanged to convert_scene_output_to_glb.
        mask_sky: if true, replace scene with ``scene.mask_sky()`` first.
        clean_depth: if true, replace scene with ``scene.clean_pointcloud()`` first.
        thr_for_init_conf: stored on ``scene.thr_for_init_conf`` before the masks are queried.

    Returns:
        None when ``scene`` is None, otherwise whatever convert_scene_output_to_glb returns.
    """
    if scene is None:
        return None

    # Optional post-processing; each step hands back the scene to continue with.
    if clean_depth:
        scene = scene.clean_pointcloud()
    if mask_sky:
        scene = scene.mask_sky()

    images = scene.imgs
    focals = scene.get_focals().cpu()
    poses_c2w = scene.get_im_poses().cpu()
    points = to_numpy(scene.get_pts3d(raw_pts=True))

    # Thresholds are assigned before get_masks() so the masks presumably use them.
    scene.min_conf_thr = min_conf_thr
    scene.thr_for_init_conf = thr_for_init_conf
    masks = to_numpy(scene.get_masks())

    # One viridis colour per camera, RGB scaled from [0, 1] to [0, 255].
    cmap = pl.get_cmap('viridis')
    n_views = len(images)
    cam_color = [tuple(255 * channel for channel in cmap(k / n_views)[:3]) for k in range(n_views)]

    return convert_scene_output_to_glb(
        outdir, images, points, masks, focals, poses_c2w,
        as_pointcloud=as_pointcloud,
        transparent_cams=transparent_cams,
        cam_size=cam_size,
        show_cam=show_cam,
        silent=silent,
        save_name=save_name,
        cam_color=cam_color,
    )
|
|
|
|
|
def _depth_pro_maps(img_list):
    """Predict Depth Pro depth maps and focal lengths for each image path."""
    model, transform = depth_pro.create_model_and_transforms(device='cuda')
    model.eval()

    depth_list, focallength_px_list = [], []
    for image_path in tqdm(img_list):
        rgb_image, _, f_px = depth_pro.load_rgb(image_path)
        prediction = model.infer(transform(rgb_image), f_px=f_px)
        # Depth Pro's infer() already resamples its depth back to the input
        # resolution, so no extra resize is applied. (After `transform`, the image
        # is a torch tensor whose `.size` is a method, not a (W, H) size.)
        depth_list.append(prediction["depth"].cpu().numpy())
        focallength_px_list.append(prediction["focallength_px"].cpu())
    return depth_list, focallength_px_list


def _depth_anything_v2_maps(img_list, focallength_px=200):
    """Predict Depth Anything V2 depth maps; every frame gets a fixed placeholder focal length."""
    pipe = pipeline(task="depth-estimation", model="depth-anything/Depth-Anything-V2-Large-hf", device='cuda')

    depth_list, focallength_px_list = [], []
    for image_path in tqdm(img_list):
        # Context manager closes the file handle once inference is done.
        with Image.open(image_path) as image:
            predicted = pipe(image)["predicted_depth"].numpy()
            width_height = image.size
        # [0] drops the leading axis of the pipeline output (assumed batch dim);
        # cv2.resize expects (width, height), the same order as PIL's Image.size.
        depth = cv2.resize(predicted[0], width_height, interpolation=cv2.INTER_LANCZOS4)
        depth_list.append(depth)
        # This model produces no intrinsics, so a constant placeholder is used.
        focallength_px_list.append(focallength_px)
    return depth_list, focallength_px_list


def generate_monocular_depth_maps(img_list, depth_prior_name):
    """Compute a monocular depth prior and focal length for every input image.

    Args:
        img_list: iterable of image file paths.
        depth_prior_name: either 'Depth Pro' or 'Depth Anything V2'.

    Returns:
        (depth_list, focallength_px_list), one entry per image in ``img_list``.
        Depth Pro yields per-image focal estimates as CPU tensors. Depth Anything V2
        yields the placeholder int 200 for every image.

    Raises:
        ValueError: if ``depth_prior_name`` is not one of the supported priors
            (e.g. None when nothing was selected in the UI).
    """
    if depth_prior_name == 'Depth Pro':
        return _depth_pro_maps(img_list)
    if depth_prior_name == 'Depth Anything V2':
        return _depth_anything_v2_maps(img_list)
    raise ValueError(
        f"Unknown depth prior {depth_prior_name!r}; expected 'Depth Pro' or 'Depth Anything V2'"
    )
|
|
|
@spaces.GPU(duration=180)
def local_get_reconstructed_scene(filelist, min_conf_thr, as_pointcloud, mask_sky, clean_depth, transparent_cams, cam_size, depth_prior_name, **kw):
    """Run the Align3R pipeline on uploaded frames and export the result as a 3D model.

    Pipeline: monocular depth priors -> load_images -> make_pairs ->
    pairwise inference with the matching Align3R checkpoint -> global alignment ->
    get_3D_model_from_scene.

    Args:
        filelist: image paths from the Gradio File component.
        min_conf_thr, as_pointcloud, mask_sky, clean_depth, transparent_cams, cam_size:
            forwarded to get_3D_model_from_scene.
        depth_prior_name: 'Depth Pro' or 'Depth Anything V2'; selects both the depth
            prior and the Align3R checkpoint.
        **kw: accepted for callback compatibility; unused.

    Returns:
        The value returned by get_3D_model_from_scene (written under tmpdirname).

    Raises:
        gradio.Error: if no images were uploaded or no depth model is selected,
            so the user sees a clear message instead of a deep pipeline failure.
    """
    if not filelist:
        raise gradio.Error("Please upload at least one image before clicking Run.")
    if depth_prior_name not in ("Depth Pro", "Depth Anything V2"):
        raise gradio.Error("Please select a monocular depth estimation model.")

    depth_list, focallength_px_list = generate_monocular_depth_maps(filelist, depth_prior_name)
    imgs = load_images(filelist, depth_list, focallength_px_list, size=image_size, verbose=not silent,
                       traj_format='custom', depth_prior_name=depth_prior_name)

    # Scene graph over the frame sequence; the name suggests a stride-5 sliding
    # window without wrap-around (semantics live in make_pairs).
    pairs = make_pairs(imgs, scene_graph='swinstride-5-noncyclic', prefilter=None, symmetrize=True)

    # The Align3R checkpoint must match the depth prior it was trained with.
    if depth_prior_name == "Depth Pro":
        weights_path = "cyun9286/Align3R_DepthPro_ViTLarge_BaseDecoder_512_dpt"
    else:
        weights_path = "cyun9286/Align3R_DepthAnythingV2_ViTLarge_BaseDecoder_512_dpt"
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = AsymmetricCroCo3DStereo.from_pretrained(weights_path).to(device)
    output = inference(pairs, model, device, batch_size=batch_size, verbose=not silent)

    # Same iteration count for the aligner's internal schedule and the optimisation loop.
    niter = 300
    scene = global_aligner(output, device=device, mode=GlobalAlignerMode.PointCloudOptimizer,
                           verbose=not silent, shared_focal=True, temporal_smoothing_weight=0.01,
                           translation_weight=1.0, flow_loss_weight=0.01, flow_loss_start_epoch=0.1,
                           flow_loss_thre=25, use_self_mask=True, num_total_iter=niter,
                           # presumably trades speed for GPU memory on long sequences
                           empty_cache=len(filelist) > 72)
    scene.compute_global_alignment(init='mst', niter=niter, schedule='linear', lr=0.01)

    return get_3D_model_from_scene(tmpdirname, silent, scene, min_conf_thr, as_pointcloud, mask_sky,
                                   clean_depth, transparent_cams, cam_size)
|
|
|
|
|
def run_example(snapshot, min_conf_thr, as_pointcloud, mask_sky, clean_depth, transparent_cams, cam_size, depth_prior_name, inputfiles, **kw):
    """Callback for gradio.Examples: ignore the preview ``snapshot`` and reconstruct from ``inputfiles``."""
    settings = (min_conf_thr, as_pointcloud, mask_sky, clean_depth, transparent_cams, cam_size, depth_prior_name)
    return local_get_reconstructed_scene(inputfiles, *settings, **kw)
|
|
|
css = """.gradio-container {margin: 0 !important; min-width: 100%};""" |
|
title = "Align3R Demo" |
|
with gradio.Blocks(css=css, title=title, delete_cache=(gradio_delete_cache, gradio_delete_cache)) as demo:
    filestate = gradio.State(None)
    gradio.HTML('<h2 style="text-align: center;">3D Reconstruction with Align3R</h2>')
    gradio.HTML('<p>Upload two images (wait for them to be fully uploaded before hitting the run button). '
                'If you want to try larger image collections, you can find the more complete version of this demo that you can run locally '
                'and more details about the method at <a href="https://github.com/jiah-cloud/Align3R">github.com/jiah-cloud/Align3R</a>. '
                'The checkpoint used in this demo is available at <a href="https://huggingface.co/cyun9286/Align3R_DepthAnythingV2_ViTLarge_BaseDecoder_512_dpt">Align3R (Depth Anything V2)</a> and <a href="https://huggingface.co/cyun9286/Align3R_DepthPro_ViTLarge_BaseDecoder_512_dpt">Align3R (Depth Pro)</a>.</p>')
    with gradio.Column():
        inputfiles = gradio.File(file_count="multiple")
        # Hidden image slot; only used as the first input column of gradio.Examples.
        snapshot = gradio.Image(None, visible=False)
        with gradio.Row():
            cam_size = gradio.Slider(label="cam_size", value=0.02, minimum=0.001, maximum=1.0, step=0.001)
            # Default to the prior used by every bundled example. Without a default the
            # dropdown yields None and the pipeline fails when Run is clicked untouched.
            depth_prior_name = gradio.Dropdown(
                ["Depth Pro", "Depth Anything V2"], value="Depth Anything V2",
                label="monocular depth estimation model", info="Select the monocular depth estimation model.")
            min_conf_thr = gradio.Slider(label="min_conf_thr", value=2, minimum=0.0, maximum=20, step=0.01)
        with gradio.Row():
            as_pointcloud = gradio.Checkbox(value=True, label="As pointcloud")
            mask_sky = gradio.Checkbox(value=True, label="Mask sky")
            clean_depth = gradio.Checkbox(value=True, label="Clean-up depthmaps")
            transparent_cams = gradio.Checkbox(value=False, label="Transparent cameras")
            # NOTE(review): show_cam is not passed to any callback below, so toggling it has no effect.
            show_cam = gradio.Checkbox(value=True, label="Show Camera")
        run_btn = gradio.Button("Run")
        outmodel = gradio.Model3D()

        # Bundled example sequences as (folder under example/, number of frames).
        # Frames are named 00000.jpg, 00001.jpg, ... and the first frame doubles
        # as the preview image.
        example_sequences = (('bear', 3), ('breakdance', 10), ('tennis', 10), ('camel', 10))
        example_rows = []
        for sequence_name, frame_count in example_sequences:
            sequence_frames = [os.path.join(HERE_PATH, f'example/{sequence_name}/{frame_idx:05d}.jpg')
                               for frame_idx in range(frame_count)]
            example_rows.append([sequence_frames[0], 2, True, True, True, False, 0.02, "Depth Anything V2",
                                 sequence_frames])

        examples = gradio.Examples(
            examples=example_rows,
            inputs=[snapshot, min_conf_thr, as_pointcloud, mask_sky, clean_depth, transparent_cams, cam_size, depth_prior_name, inputfiles],
            outputs=[outmodel],
            fn=run_example,
            cache_examples="lazy",
        )

        run_btn.click(fn=local_get_reconstructed_scene,
                      inputs=[inputfiles, min_conf_thr, as_pointcloud, mask_sky, clean_depth, transparent_cams, cam_size, depth_prior_name],
                      outputs=[outmodel])
|
|
|
# Serve the demo. Remove the scratch directory once the server stops, even if
# launch() raises; ignore_errors avoids a crash on exit if it is already gone.
try:
    demo.launch(show_error=True, share=None, server_name=None, server_port=None)
finally:
    shutil.rmtree(tmpdirname, ignore_errors=True)