Spaces:
Running
on
Zero
Running
on
Zero
File size: 7,803 Bytes
5ed9923 7dfcb4e 5ed9923 633a18a 5ed9923 e50dece c64f462 c65a25a c64f462 c65a25a c64f462 5ed9923 ccd5a59 5ed9923 c65a25a 5ed9923 633a18a 5ed9923 49b3e3d 5ed9923 49b3e3d 5ed9923 c64f462 5ed9923 e9c04b9 5ed9923 12c2bcf 597cdce 5ed9923 12c2bcf 5ed9923 12c2bcf 597cdce 5ed9923 e6feb9e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 |
#!/usr/bin/env python3
# The MASt3R Gradio demo, modified for predicting 3D Gaussian Splats
# --- Original License ---
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
import functools
import os
import sys
import tempfile
import spaces
import gradio
import torch
from huggingface_hub import hf_hub_download
sys.path.append('src/mast3r_src')
sys.path.append('src/mast3r_src/dust3r')
sys.path.append('src/pixelsplat_src')
from dust3r.utils.image import load_images
from mast3r.utils.misc import hash_md5
import main
import utils.export as export
@spaces.GPU(duration=60)
def get_reconstructed_scene(outdir, weights_path, silent, image_size, ios_mode, filelist):
    """Predict 3D Gaussian splats from one or two images and save them as a .ply file.

    Args:
        outdir: Directory in which ``gaussians.ply`` is written.
        weights_path: Checkpoint path handed to ``MAST3RGaussians.load_from_checkpoint``.
        silent: When True, ``load_images`` is called with ``verbose=False``.
        image_size: ``size`` argument forwarded to ``load_images``.
        ios_mode: When True, ``filelist`` is the value of a ``gradio.Gallery`` and only the
            first element of each entry is used as the image path.
        filelist: The one or two uploaded images. A single image is paired with itself.

    Returns:
        Path to the written ``gaussians.ply`` file.

    Raises:
        gradio.Error: If no images, or more than two images, were provided.
    """
    # Validate before loading the checkpoint so invalid input does not pay for a model load.
    # An explicit raise (rather than ``assert``) survives ``python -O``, and gradio.Error shows
    # the message in the UI. ``not filelist`` also covers an empty upload arriving as None.
    if not filelist or len(filelist) > 2:
        raise gradio.Error("Please provide one or two images")

    if ios_mode:
        # NOTE(review): Gallery entries are presumably (filepath, caption) pairs; keep the path.
        filelist = [f[0] for f in filelist]
    if len(filelist) == 1:
        # The model consumes an image pair, so reuse the single image for both views.
        filelist = [filelist[0], filelist[0]]

    # @TEMP: Temporarily instantiating the model here every time to avoid pickling issues with Hugging Face Spaces
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = main.MAST3RGaussians.load_from_checkpoint(weights_path, device)
    model = model.to(device)
    # Inference only: put layers such as dropout/normalization into evaluation mode.
    model.eval()

    imgs = load_images(filelist, size=image_size, verbose=not silent)
    for img in imgs:
        img['img'] = img['img'].to(device)
        img['original_img'] = img['original_img'].to(device)
        img['true_shape'] = torch.from_numpy(img['true_shape'])

    # No gradients are needed; skipping autograd bookkeeping saves GPU memory.
    with torch.no_grad():
        pred1, pred2 = model(imgs[0], imgs[1])

    plyfile = os.path.join(outdir, 'gaussians.ply')
    export.save_as_ply(pred1, pred2, plyfile)
    return plyfile
if __name__ == '__main__':
    image_size = 512
    silent = False
    # True: images are uploaded through a gradio.Gallery; False: through a multi-file gradio.File.
    ios_mode = True

    model_name = "brandonsmart/splatt3r_v1.0"
    filename = "epoch=19-step=1200.ckpt"
    weights_path = hf_hub_download(repo_id=model_name, filename=filename)
    chkpt_tag = hash_md5(weights_path)

    # Define example inputs and their corresponding precalculated outputs.
    # Each row is [first image, second image, precomputed .ply], all stored in the model repo.
    example_stems = (
        [f"scannet++_{k}" for k in range(1, 9)]
        + [f"in_the_wild_{k}" for k in range(1, 9)]
    )
    examples = [
        [
            f"demo_examples/{stem}_img_1.jpg",
            f"demo_examples/{stem}_img_2.jpg",
            f"demo_examples/{stem}.ply",
        ]
        for stem in example_stems
    ]
    # Replace every repo-relative name with the local path of its downloaded copy.
    examples = [
        [hf_hub_download(repo_id=model_name, filename=name) for name in row]
        for row in examples
    ]

    with tempfile.TemporaryDirectory(suffix='_mast3r_gradio_demo') as tmpdirname:
        # NOTE(review): nothing visible in this file reads cache_path; likely a leftover
        # from the original MASt3R demo.
        cache_path = os.path.join(tmpdirname, chkpt_tag)
        os.makedirs(cache_path, exist_ok=True)

        recon_fun = functools.partial(get_reconstructed_scene, tmpdirname, weights_path, silent, image_size, ios_mode)

        if not ios_mode:
            # The multi-file input receives both images as a single example value at index 2.
            examples = [[img1, img2, (img1, img2), ply] for img1, img2, ply in examples]

        # No trailing ';' after the closing brace: a stray top-level semicolon is invalid CSS
        # and could swallow a rule concatenated after this one.
        css = """.gradio-container {margin: 0 !important; min-width: 100%}"""
        with gradio.Blocks(css=css, title="Splatt3R Demo") as demo:

            gradio.HTML('<h2 style="text-align: center;">Splatt3R Demo</h2>')

            with gradio.Column():
                gradio.Markdown('''
Please upload exactly one or two images below to be used for reconstruction.
If non-square images are uploaded, they will be cropped to squares for reconstruction.
Due to the time taken to instantiate the model and download the file from Hugging Face, it may take up to 60 seconds to see your result.
Consider running the model locally from [here](https://github.com/btsmart/splatt3r) for near-instantaneous results.
''')
                if ios_mode:
                    inputfiles = gradio.Gallery(type="filepath")
                else:
                    inputfiles = gradio.File(file_count="multiple")
                run_btn = gradio.Button("Run")

                gradio.Markdown('''
## Output
Below we show the generated 3D Gaussian Splat.
The generated splats are 30-40MB, so please allow up to 30 seconds for them to be downloaded from Hugging Face before rendering.
As it downloads your previous generations may be visible.
The arrow in the top right of the window below can be used to download the .ply for rendering with other viewers,
such as [here](https://projects.markkellogg.org/threejs/demo_gaussian_splats_3d.php?art=1&cu=0,-1,0&cp=0,1,0&cla=1,0,0&aa=false&2d=false&sh=0) or [here](https://playcanvas.com/supersplat/editor).
''')
                outmodel = gradio.Model3D(
                    clear_color=[1.0, 1.0, 1.0, 0.0],
                )
                run_btn.click(fn=recon_fun, inputs=[inputfiles], outputs=[outmodel])

                gradio.Markdown('''
## Examples
A gallery of examples generated from ScanNet++ and from 'in the wild' images taken with a mobile phone.
These examples are 30-40MB, so please allow up to 30 seconds for them to be downloaded from Hugging Face before rendering.
As it downloads your previous generations may be visible.
''')

                # Hidden components that receive the two example images when an example is clicked.
                snapshot_1 = gradio.Image(None, visible=False)
                snapshot_2 = gradio.Image(None, visible=False)
                if ios_mode:
                    gradio.Examples(
                        examples=examples,
                        inputs=[snapshot_1, snapshot_2, outmodel],
                        examples_per_page=5
                    )
                else:
                    gradio.Examples(
                        examples=examples,
                        inputs=[snapshot_1, snapshot_2, inputfiles, outmodel],
                        examples_per_page=5
                    )

        demo.launch()
|