#!/usr/bin/env python3 # The MASt3R Gradio demo, modified for predicting 3D Gaussian Splats # --- Original License --- # Copyright (C) 2024-present Naver Corporation. All rights reserved. # Licensed under CC BY-NC-SA 4.0 (non-commercial use only). import functools import os import sys import tempfile import gradio import torch from huggingface_hub import hf_hub_download sys.path.append('src/mast3r_src') sys.path.append('src/mast3r_src/dust3r') sys.path.append('src/pixelsplat_src') from dust3r.utils.image import load_images from mast3r.utils.misc import hash_md5 import main import utils.export as export def get_reconstructed_scene(outdir, model, device, silent, image_size, ios_mode, filelist): if ios_mode: filelist = [f[0] for f in filelist] if len(filelist) == 1: filelist = [filelist[0], filelist[0]] assert len(filelist) == 2, "Please provide two images" imgs = load_images(filelist, size=image_size, verbose=not silent) for img in imgs: img['img'] = img['img'].to(device) img['original_img'] = img['original_img'].to(device) img['true_shape'] = torch.from_numpy(img['true_shape']) output = model(imgs[0], imgs[1]) pred1, pred2 = output plyfile = os.path.join(outdir, 'gaussians.ply') export.save_as_ply(pred1, pred2, plyfile) return plyfile if __name__ == '__main__': image_size = 512 silent = False ios_mode = True device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') model_name = "brandonsmart/splatt3r_v1.0" filename = "epoch=19-step=1200.ckpt" weights_path = hf_hub_download(repo_id=model_name, filename=filename) model = main.MAST3RGaussians.load_from_checkpoint(weights_path, device) chkpt_tag = hash_md5(weights_path) # Define example inputs and their corresponding precalculated outputs examples = [ ["demo_examples/scannet++_1_img_1.jpg", "demo_examples/scannet++_1_img_2.jpg", "demo_examples/scannet++_1.ply"], ["demo_examples/scannet++_2_img_1.jpg", "demo_examples/scannet++_2_img_2.jpg", "demo_examples/scannet++_2.ply"], ["demo_examples/scannet++_3_img_1.jpg", "demo_examples/scannet++_3_img_2.jpg", "demo_examples/scannet++_3.ply"], ["demo_examples/scannet++_4_img_1.jpg", "demo_examples/scannet++_4_img_2.jpg", "demo_examples/scannet++_4.ply"], ["demo_examples/scannet++_5_img_1.jpg", "demo_examples/scannet++_5_img_2.jpg", "demo_examples/scannet++_5.ply"], ["demo_examples/scannet++_6_img_1.jpg", "demo_examples/scannet++_6_img_2.jpg", "demo_examples/scannet++_6.ply"], ["demo_examples/scannet++_7_img_1.jpg", "demo_examples/scannet++_7_img_2.jpg", "demo_examples/scannet++_7.ply"], ["demo_examples/scannet++_8_img_1.jpg", "demo_examples/scannet++_8_img_2.jpg", "demo_examples/scannet++_8.ply"], ["demo_examples/in_the_wild_1_img_1.jpg", "demo_examples/in_the_wild_1_img_2.jpg", "demo_examples/in_the_wild_1.ply"], ["demo_examples/in_the_wild_2_img_1.jpg", "demo_examples/in_the_wild_2_img_2.jpg", "demo_examples/in_the_wild_2.ply"], ["demo_examples/in_the_wild_3_img_1.jpg", "demo_examples/in_the_wild_3_img_2.jpg", "demo_examples/in_the_wild_3.ply"], ["demo_examples/in_the_wild_4_img_1.jpg", "demo_examples/in_the_wild_4_img_2.jpg", "demo_examples/in_the_wild_4.ply"], ["demo_examples/in_the_wild_5_img_1.jpg", "demo_examples/in_the_wild_5_img_2.jpg", "demo_examples/in_the_wild_5.ply"], ["demo_examples/in_the_wild_6_img_1.jpg", "demo_examples/in_the_wild_6_img_2.jpg", "demo_examples/in_the_wild_6.ply"], ["demo_examples/in_the_wild_7_img_1.jpg", "demo_examples/in_the_wild_7_img_2.jpg", "demo_examples/in_the_wild_7.ply"], ["demo_examples/in_the_wild_8_img_1.jpg", "demo_examples/in_the_wild_8_img_2.jpg", "demo_examples/in_the_wild_8.ply"], ] for i in range(len(examples)): for j in range(len(examples[i])): examples[i][j] = hf_hub_download(repo_id=model_name, filename=examples[i][j]) with tempfile.TemporaryDirectory(suffix='_mast3r_gradio_demo') as tmpdirname: cache_path = os.path.join(tmpdirname, chkpt_tag) os.makedirs(cache_path, exist_ok=True) recon_fun = functools.partial(get_reconstructed_scene, tmpdirname, model, device, silent, image_size, ios_mode) if not ios_mode: for i in range(len(examples)): examples[i].insert(2, (examples[i][0], examples[i][1])) css = """.gradio-container {margin: 0 !important; min-width: 100%};""" with gradio.Blocks(css=css, title="Splatt3R Demo") as demo: gradio.HTML('