Spaces:
Running
on
Zero
Running
on
Zero
import os, subprocess, shlex, sys, gc | |
import time | |
import torch | |
import numpy as np | |
import shutil | |
import argparse | |
import gradio as gr | |
import uuid | |
import spaces | |
# Binary wheels shipped with the app; they are installed at start-up, before
# the project imports that follow. Paths are relative to the current working
# directory, exactly as in the original commands.
_PREBUILT_WHEELS = (
    "wheel/diff_gaussian_rasterization-0.0.0-cp310-cp310-linux_x86_64.whl",
    "wheel/simple_knn-0.0.0-cp310-cp310-linux_x86_64.whl",
    "wheel/curope-0.0.0-cp310-cp310-linux_x86_64.whl",
)

for _wheel in _PREBUILT_WHEELS:
    # Use the running interpreter's pip (``python -m pip``) so the wheel lands
    # in the environment this process actually imports from, instead of
    # whichever ``pip`` executable happens to be first on PATH.
    _pip_result = subprocess.run([sys.executable, "-m", "pip", "install", _wheel])
    if _pip_result.returncode != 0:
        # Stay best-effort (do not abort start-up), but make the failure visible.
        print(
            f"WARNING: pip install failed for {_wheel} "
            f"(exit code {_pip_result.returncode})",
            file=sys.stderr,
        )

# Directory containing this file; used to locate the vendored dust3r checkout.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
# Make the `dust3r` package under submodules/ importable.
# (`sys.path` is used directly; `os.sys` only works by accident of how the
# `os` module imports `sys`.)
sys.path.append(os.path.abspath(os.path.join(BASE_DIR, "submodules", "dust3r")))
# os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True' | |
from dust3r.inference import inference | |
from dust3r.model import AsymmetricCroCo3DStereo | |
from dust3r.utils.device import to_numpy | |
from dust3r.image_pairs import make_pairs | |
from dust3r.cloud_opt import global_aligner, GlobalAlignerMode | |
from utils.dust3r_utils import compute_global_alignment, load_images, storePly, save_colmap_cameras, save_colmap_images | |
from argparse import ArgumentParser, Namespace | |
from arguments import ModelParams, PipelineParams, OptimizationParams | |
from train_joint import training | |
from render_by_interp import render_sets | |
# Default root folder for per-run working directories (see --base_path).
GRADIO_CACHE_FOLDER = './gradio_cache_folder'
#############################################################################################################################################


def _str2bool(value):
    """Convert a command-line token such as ``true``/``False``/``0`` to a bool.

    ``argparse`` with ``type=bool`` calls ``bool("False")``, which is ``True``
    for any non-empty string; this converter parses the text instead.

    Raises:
        argparse.ArgumentTypeError: if the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in {"true", "t", "yes", "y", "1"}:
        return True
    # The empty string is kept falsy, matching the old ``bool("")`` result.
    if token in {"false", "f", "no", "n", "0", ""}:
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def get_dust3r_args_parser():
    """Build the argument parser for the DUSt3R geometry-initialisation stage.

    Returns:
        argparse.ArgumentParser: parser exposing ``--image_size``,
        ``--model_path``, ``--device``, ``--batch_size``, ``--schedule``,
        ``--lr``, ``--niter``, ``--focal_avg``, ``--n_views`` and
        ``--base_path``, each with the default used by the demo.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--image_size", type=int, default=512, choices=[512, 224], help="image size")
    parser.add_argument("--model_path", type=str, default="submodules/dust3r/checkpoints/DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth", help="path to the model weights")
    parser.add_argument("--device", type=str, default='cuda', help="pytorch device")
    parser.add_argument("--batch_size", type=int, default=1)
    parser.add_argument("--schedule", type=str, default='linear')
    parser.add_argument("--lr", type=float, default=0.01)
    parser.add_argument("--niter", type=int, default=300)
    # Parsed with _str2bool so that "--focal_avg False" really disables it.
    parser.add_argument("--focal_avg", type=_str2bool, default=True)
    parser.add_argument("--n_views", type=int, default=3)
    parser.add_argument("--base_path", type=str, default=GRADIO_CACHE_FOLDER)
    return parser
def process(inputfiles, input_path=None):
    """Run the three-stage demo pipeline on a set of images.

    Stages: (1) DUSt3R inference and global alignment, written out as
    COLMAP-style text files plus a point cloud; (2) 3D-Gaussian training via
    ``training``; (3) video rendering via ``render_sets``.

    Args:
        inputfiles: list of image file paths supplied by the upload widget.
            Ignored (may be ``None``) when ``input_path`` is given.
        input_path: optional name of a bundled example folder under
            ``./assets/example/``; when set, its files are used as input and
            are copied (not moved) into the working directory.

    Returns:
        tuple: ``(video_path, ply_path, ply_path)``; the ``.ply`` path is
        returned twice, once per output slot that consumes it.

    Raises:
        gr.Error: if fewer than two images are supplied, if the number of
            files in the working image folder differs from the number of
            inputs, or if the images do not all share one resolution.
    """
    if input_path is not None:
        # Example mode: build the input list from the bundled example folder.
        imgs_path = './assets/example/' + input_path
        imgs_names = sorted(os.listdir(imgs_path))
        inputfiles = []
        for imgs_name in imgs_names:
            file_path = os.path.join(imgs_path, imgs_name)
            print(file_path)
            inputfiles.append(file_path)
        print(inputfiles)

    # Validate before touching the disk or loading any weights. This also
    # covers "no upload" (None) and an empty list, which previously slipped
    # past the `== 1` check or crashed in len().
    n_inputs = len(inputfiles) if inputfiles else 0
    if n_inputs < 2:
        raise gr.Error("The number of input images should be greater than 1.")

    # ------ (1) Coarse Geometric Initialization ------
    parser = get_dust3r_args_parser()
    opt = parser.parse_args()
    # Unique per-call working directory (32 hex characters).
    tmp_user_folder = uuid.uuid4().hex
    opt.img_base_path = os.path.join(opt.base_path, tmp_user_folder)
    img_folder_path = os.path.join(opt.img_base_path, "images")
    os.makedirs(img_folder_path, exist_ok=True)
    opt.n_views = n_inputs
    print("Multiple images: ", inputfiles)
    for image_path in inputfiles:
        if input_path is not None:
            # Keep the bundled example assets in place for later runs.
            shutil.copy(image_path, img_folder_path)
        else:
            shutil.move(image_path, img_folder_path)
    train_img_list = sorted(os.listdir(img_folder_path))
    # Explicit raise instead of `assert`, which is stripped under `python -O`.
    if len(train_img_list) != opt.n_views:
        raise gr.Error(f"Number of images in the folder is not equal to {opt.n_views}")

    images, ori_size, imgs_resolution = load_images(img_folder_path, size=512)
    if len(set(imgs_resolution)) != 1:
        raise gr.Error("The resolution of the input image should be the same.")
    print("ori_size", ori_size)

    # Load the weights only once the inputs are known to be usable.
    dust3r_model = AsymmetricCroCo3DStereo.from_pretrained(opt.model_path).to(opt.device)

    start_time = time.time()
    ######################################################
    pairs = make_pairs(images, scene_graph='complete', prefilter=None, symmetrize=True)
    output = inference(pairs, dust3r_model, opt.device, batch_size=opt.batch_size)
    # Built from the run folder directly; str.replace("images", ...) on the
    # full path would also rewrite an "images" substring inside --base_path.
    output_colmap_path = os.path.join(opt.img_base_path, "sparse", "0")
    os.makedirs(output_colmap_path, exist_ok=True)

    scene = global_aligner(output, device=opt.device, mode=GlobalAlignerMode.PointCloudOptimizer)
    compute_global_alignment(scene=scene, init="mst", niter=opt.niter, schedule=opt.schedule, lr=opt.lr, focal_avg=opt.focal_avg)
    scene = scene.clean_pointcloud()
    imgs = to_numpy(scene.imgs)
    focals = scene.get_focals()
    poses = to_numpy(scene.get_im_poses())
    pts3d = to_numpy(scene.get_pts3d())
    scene.min_conf_thr = float(scene.conf_trf(torch.tensor(1.0)))
    confidence_masks = to_numpy(scene.get_masks())
    intrinsics = to_numpy(scene.get_intrinsics())
    ######################################################
    end_time = time.time()
    print(f"Time taken for {opt.n_views} views: {end_time-start_time} seconds")

    save_colmap_cameras(ori_size, intrinsics, os.path.join(output_colmap_path, 'cameras.txt'))
    save_colmap_images(poses, os.path.join(output_colmap_path, 'images.txt'), train_img_list)
    # Keep only the points/colours selected by the per-view confidence masks.
    pts_4_3dgs = np.concatenate([p[m] for p, m in zip(pts3d, confidence_masks)])
    color_4_3dgs = np.concatenate([p[m] for p, m in zip(imgs, confidence_masks)])
    color_4_3dgs = (color_4_3dgs * 255.0).astype(np.uint8)
    storePly(os.path.join(output_colmap_path, "points3D.ply"), pts_4_3dgs, color_4_3dgs)
    pts_4_3dgs_all = np.array(pts3d).reshape(-1, 3)
    np.save(output_colmap_path + "/pts_4_3dgs_all.npy", pts_4_3dgs_all)
    np.save(output_colmap_path + "/focal.npy", np.array(focals.cpu()))

    ### save VRAM
    # Drop every stage-1 object that is no longer needed (none of these names
    # are read below), collect garbage first so their tensors are actually
    # freed, then release the cached CUDA blocks before training starts.
    del scene, output, pairs, images, dust3r_model, focals
    gc.collect()
    torch.cuda.empty_cache()

    ##################################################################################################################################################
    # ------ (2) Fast 3D-Gaussian Optimization ------
    def _parse_bool(value):
        # argparse's type=bool treats any non-empty string (even "False") as True.
        token = str(value).strip().lower()
        if token in {"true", "t", "yes", "y", "1"}:
            return True
        if token in {"false", "f", "no", "n", "0", ""}:
            return False
        raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")

    parser = ArgumentParser(description="Training script parameters")
    lp = ModelParams(parser)
    op = OptimizationParams(parser)
    pp = PipelineParams(parser)
    parser.add_argument('--debug_from', type=int, default=-1)
    parser.add_argument("--test_iterations", nargs="+", type=int, default=[])
    parser.add_argument("--save_iterations", nargs="+", type=int, default=[])
    parser.add_argument("--checkpoint_iterations", nargs="+", type=int, default=[])
    parser.add_argument("--start_checkpoint", type=str, default=None)
    parser.add_argument("--scene", type=str, default="demo")
    parser.add_argument("--n_views", type=int, default=3)
    parser.add_argument("--get_video", action="store_true")
    parser.add_argument("--optim_pose", type=_parse_bool, default=True)
    parser.add_argument("--skip_train", action="store_true")
    parser.add_argument("--skip_test", action="store_true")
    args = parser.parse_args(sys.argv[1:])
    args.save_iterations.append(args.iterations)
    args.model_path = opt.img_base_path + '/output/'
    args.source_path = opt.img_base_path
    # NOTE(review): the save step appended above is `args.iterations`, while
    # rendering and the returned .ply path use `args.iteration` (1000).
    # Confirm the `iterations` default declared elsewhere agrees with 1000.
    args.iteration = 1000
    os.makedirs(args.model_path, exist_ok=True)
    training(lp.extract(args), op.extract(args), pp.extract(args), args.test_iterations, args.save_iterations, args.checkpoint_iterations, args.start_checkpoint, args.debug_from, args)

    ##################################################################################################################################################
    # ------ (3) Render video by interpolation ------
    parser = ArgumentParser(description="Testing script parameters")
    model = ModelParams(parser, sentinel=True)
    pipeline = PipelineParams(parser)
    args.eval = True
    args.get_video = True
    args.n_views = opt.n_views
    render_sets(
        model.extract(args),
        args.iteration,
        pipeline.extract(args),
        args.skip_train,
        args.skip_test,
        args,
    )
    output_ply_path = opt.img_base_path + f'/output/point_cloud/iteration_{args.iteration}/point_cloud.ply'
    output_video_path = opt.img_base_path + f'/output/demo_{opt.n_views}_view.mp4'
    return output_video_path, output_ply_path, output_ply_path
################################################################################################################################################## | |
# Short application title. In this file it is only referenced from
# commented-out code (the Blocks title and a Markdown heading).
_TITLE = '''InstantSplat'''
# Mixed HTML/Markdown header rendered at the top of the page via gr.Markdown:
# a centred headline, badge links (project page, demo video, arXiv) and a
# short bullet list of usage notes.
_DESCRIPTION = '''
<div style="display: flex; justify-content: center; align-items: center;">
<div style="width: 100%; text-align: center; font-size: 30px;">
<strong>InstantSplat: Sparse-view SfM-free Gaussian Splatting in Seconds</strong>
</div>
</div>
<p></p>
<div align="center">
<a style="display:inline-block" href="https://instantsplat.github.io/"><img src='https://img.shields.io/badge/Project_Page-1c7d45?logo=gumtree'></a>
<a style="display:inline-block" href="https://www.youtube.com/watch?v=fxf_ypd7eD8"><img src='https://img.shields.io/badge/Demo_Video-E33122?logo=Youtube'></a>
<a style="display:inline-block" href="https://arxiv.org/abs/2403.20309"><img src="https://img.shields.io/badge/ArXiv-2403.20309-b31b1b?logo=arxiv" alt='arxiv'></a>
</div>
<p></p>
* Official demo of: [InstantSplat: Sparse-view SfM-free Gaussian Splatting in Seconds](https://instantsplat.github.io/).
* Sparse-view examples for direct viewing: you can simply click the examples (in the bottom of the page), to quickly view the results on representative data.
* Training speeds may slow if the resolution or number of images is large. To achieve performance comparable to what has been reported, please conduct tests on your own GPU (A100/4090).
'''
# <a style="display:inline-block" href="https://github.com/VITA-Group/LightGaussian"><img src="https://img.shields.io/badge/Source_Code-black?logo=Github" alt='Github Source Code'></a> | |
# | |
# <a style="display:inline-block" href="https://www.nvidia.com/en-us/"><img src="https://img.shields.io/badge/Nvidia-575757?logo=nvidia" alt='Nvidia'></a> | |
# * If InstantSplat is helpful, please give us a star ⭐ on Github. Thanks! <a style="display:inline-block; margin-left: .5em" href="https://github.com/VITA-Group/LightGaussian"><img src='https://img.shields.io/github/stars/VITA-Group/LightGaussian?style=social'/></a> | |
# block = gr.Blocks(title=_TITLE).queue() | |
# Gradio UI: a Blocks app with the request queue enabled.
# NOTE(review): the nesting below was reconstructed from a copy of the file
# whose indentation was lost; confirm the layout hierarchy against the
# running app.
block = gr.Blocks().queue()
with block:
    # Header row: project description and links.
    with gr.Row():
        with gr.Column(scale=1):
            # gr.Markdown('# ' + _TITLE)
            gr.Markdown(_DESCRIPTION)
    # Input panel: multi-file upload plus the RUN button.
    with gr.Row(variant='panel'):
        with gr.Tab("Input"):
            inputfiles = gr.File(file_count="multiple", label="images")
            # Hidden textbox that carries the selected example folder name.
            input_path = gr.Textbox(visible=False, label="example_path")
            button_gen = gr.Button("RUN")
    # Output panel: 3D viewer and .ply download on one side, video on the other.
    with gr.Row(variant='panel'):
        with gr.Tab("Output"):
            with gr.Column(scale=2):
                output_model = gr.Model3D(
                    label="3D Model (Gaussian)",
                    # height=300,
                    interactive=False,
                    # clear_color=[1.0, 1.0, 1.0, 1.0]
                )
                output_file = gr.File(label="ply")
            with gr.Column(scale=1):
                output_video = gr.Video(label="video")
    # Upload path: only `inputfiles` is passed, so process() runs with its
    # default input_path=None. Output order matches process()'s return tuple
    # (video path, ply path, ply path).
    button_gen.click(process, inputs=[inputfiles], outputs=[ output_video, output_file, output_model])
    # Example path: each string is passed to process() as `input_path`, which
    # resolves it under ./assets/example/.
    gr.Examples(
        examples=[
            "sora-santorini-3-views",
            "TT-family-3-views",
            "dl3dv-ba55-3-views",
        ],
        inputs=[input_path],
        outputs=[output_video, output_file, output_model],
        fn=lambda x: process(inputfiles=None, input_path=x),
        cache_examples=True,
        label='Sparse-view Examples'
    )
# Listen on all network interfaces; no public share link is requested.
block.launch(server_name="0.0.0.0", share=False)