# dreamgaussian4d / app.py
# Last commit: 4fd53d4 by jiaweir ("change fps")
import gradio as gr
import os
from PIL import Image
import subprocess
from gradio_model4dgs import Model4DGS
import numpy
import hashlib
import shlex
subprocess.run(shlex.split("pip install wheels/diff_gaussian_rasterization-0.0.0-cp310-cp310-linux_x86_64.whl"))
import rembg
import glob
import cv2
import numpy as np
from diffusers import StableVideoDiffusionPipeline
from scripts.gen_vid import *
import sys
sys.path.append('lgm')
from safetensors.torch import load_file
from kiui.cam import orbit_camera
from core.options import config_defaults, Options
from core.models import LGM
from mvdream.pipeline_mvdream import MVDreamPipeline
from infer_demo import process as process_lgm
from main_4d_demo import process as process_dg4d
import spaces
from huggingface_hub import hf_hub_download
# Local path of the pretrained LGM weights file "model_fp16_fixrot.safetensors"
# from the ashawkey/LGM Hub repo (hf_hub_download returns the cached file path).
ckpt_path = hf_hub_download(repo_id="ashawkey/LGM", filename="model_fp16_fixrot.safetensors")
# Client-side JavaScript handed to gr.Blocks(js=...), which Gradio runs on page
# load: if the URL does not already carry ?__theme=light, it adds that query
# parameter and reloads, forcing the demo into the light theme.
js_func = """
function refresh() {
const url = new URL(window.location);
if (url.searchParams.get('__theme') !== 'light') {
url.searchParams.set('__theme', 'light');
window.location.href = url.href;
}
}
"""
# Shared target device for the models loaded in this module.
device = torch.device('cuda')

# Background-removal session (u2net), reused by every preprocess() call.
session = rembg.new_session(model_name='u2net')

# Image-to-video diffusion model with fp16 weights; used by gen_vid() in stage 0.
pipe = StableVideoDiffusionPipeline.from_pretrained(
    "stabilityai/stable-video-diffusion-img2vid",
    variant="fp16",
    torch_dtype=torch.float16,
)
pipe.to(device)
# LGM reconstruction model, built from the 'big' preset and pointed at the
# downloaded checkpoint.
opt = config_defaults['big']
opt.resume = ckpt_path

model = LGM(opt)

# Load pretrained weights. strict=False tolerates key mismatches between the
# checkpoint and the model definition.
if opt.resume is None:
    print('[WARN] model randomly initialized, are you sure?')
else:
    ckpt = (
        load_file(opt.resume, device='cpu')
        if opt.resume.endswith('safetensors')
        else torch.load(opt.resume, map_location='cpu')
    )
    model.load_state_dict(ckpt, strict=False)
    print(f'[INFO] Loaded checkpoint from {opt.resume}')

# Half precision on the shared device, inference mode.
model = model.half().to(device)
model.eval()
# Ray embeddings from prepare_default_rays; handed to process_lgm in the 3D stage.
rays_embeddings = model.prepare_default_rays(device)
# ImageDream multi-view diffusion pipeline (fp16, remote code from the Hub);
# handed to process_lgm by the 3D stage.
pipe_mvdream = MVDreamPipeline.from_pretrained(
    "ashawkey/imagedream-ipmv-diffusers",  # remote weights
    trust_remote_code=True,
    torch_dtype=torch.float16,
).to(device)

# Stable Zero123 guidance model; handed to process_dg4d by the 4D stage.
from guidance.zero123_utils import Zero123
guidance_zero123 = Zero123(device, model_key='ashawkey/stable-zero123-diffusers')
def _recenter(carved_image, mask, size, border_ratio):
    """Return a ``size`` x ``size`` uint8 canvas with the masked object scaled and centred.

    The object's bounding box is scaled so its longer side fills
    ``size * (1 - border_ratio)`` pixels; everything else is left as zeros.

    Raises:
        ValueError: if ``mask`` contains no foreground pixels.
    """
    rows, cols = np.nonzero(mask)
    if rows.size == 0:
        raise ValueError('background removal left no foreground pixels to recenter')
    # Bounding box as half-open slice bounds: +1 so the last foreground
    # row/column is included in the crop rather than cut off.
    x_min, x_max = rows.min(), rows.max() + 1
    y_min, y_max = cols.min(), cols.max() + 1
    h = x_max - x_min
    w = y_max - y_min
    desired_size = int(size * (1 - border_ratio))
    scale = desired_size / max(h, w)
    # Clamp to 1 px so a very thin object never produces an empty resize target.
    h2 = max(1, int(h * scale))
    w2 = max(1, int(w * scale))
    x2_min = (size - h2) // 2
    y2_min = (size - w2) // 2
    final_rgba = np.zeros((size, size, 4), dtype=np.uint8)
    final_rgba[x2_min:x2_min + h2, y2_min:y2_min + w2] = cv2.resize(
        carved_image[x_min:x_max, y_min:y_max], (w2, h2), interpolation=cv2.INTER_AREA
    )
    return final_rgba


def preprocess(path, recenter=True, size=256, border_ratio=0.2):
    """Remove the background of the image at ``path`` and save the result as RGBA.

    The output is written next to the input as ``<stem>_rgba.png``
    (e.g. ``tmp_data/abc.png`` -> ``tmp_data/abc_rgba.png``).

    Args:
        path: Path of the input image.
        recenter: If True, crop the foreground (alpha > 0) to its bounding box
            and paste it, scaled, into the centre of a ``size`` x ``size``
            transparent canvas. If False, the background-removed image is
            written at its original resolution.
        size: Side length in pixels of the recentred output (ignored when
            ``recenter`` is False).
        border_ratio: Fraction of ``size`` left as empty margin around the
            object when recentring.

    Raises:
        ValueError: if the image cannot be read, or if recentring is requested
            but background removal left no foreground pixels.
    """
    out_dir = os.path.dirname(path)
    # splitext keeps dots inside the stem ("a.b.png" -> "a.b").
    out_base = os.path.splitext(os.path.basename(path))[0]
    out_rgba = os.path.join(out_dir, out_base + '_rgba.png')

    print(f'[INFO] loading image {path}...')
    image = cv2.imread(path, cv2.IMREAD_UNCHANGED)
    if image is None:
        # cv2.imread reports failure by returning None instead of raising.
        raise ValueError(f'could not read image {path}')

    print('[INFO] background removal...')
    carved_image = rembg.remove(image, session=session)  # [H, W, 4]
    mask = carved_image[..., -1] > 0

    if recenter:
        print('[INFO] recenter...')
        final_rgba = _recenter(carved_image, mask, size, border_ratio)
    else:
        final_rgba = carved_image

    cv2.imwrite(out_rgba, final_rgba)
def gen_vid(input_path, seed, bg='white', fps=14):
    """Generate a short video from one image with the Stable Video Diffusion pipeline.

    Writes, in the directory of ``input_path``:
      * ``<stem>_generated.mp4``: the generated clip, encoded at ``fps``;
      * ``<stem>_frames/NNN.png``: every generated frame.

    Args:
        input_path: Path of the conditioning image.
        seed: Seed passed to ``torch.manual_seed`` (this also reseeds torch's
            global default generator).
        bg: Background option passed through to ``load_image``.
        fps: Frame rate of the written mp4; defaults to the previously
            hard-coded 14.
    """
    input_dir = os.path.dirname(input_path)
    # os.path + splitext: platform separator, and dots inside the stem are kept.
    name = os.path.splitext(os.path.basename(input_path))[0]
    height, width = 512, 512
    image = load_image(input_path, width, height, bg)
    generator = torch.manual_seed(seed)
    frames = pipe(image, height, width, generator=generator).frames[0]
    # os.path.join (not f"{input_dir}/...") so an input without a directory
    # component writes to the CWD instead of the filesystem root.
    imageio.mimwrite(os.path.join(input_dir, f"{name}_generated.mp4"), frames, fps=fps)
    frames_dir = os.path.join(input_dir, f"{name}_frames")
    os.makedirs(frames_dir, exist_ok=True)
    for idx, img in enumerate(frames):
        img.save(os.path.join(frames_dir, f"{idx:03}.png"))
def check_img_input(control_image):
    """Guard a generation step: require that an input image was uploaded or selected.

    Raises:
        gr.Error: with a user-facing message when ``control_image`` is None.
    """
    if control_image is not None:
        return
    raise gr.Error("Please select or upload an input image")
def check_video_3d_input(image_block: Image.Image):
    """Guard the 4D step: require an image whose video and 3D outputs already exist.

    The image is identified by the SHA-256 of its raw pixel bytes, the same key
    the earlier stages use to name their output files.

    Raises:
        gr.Error: if no image is given, if ``tmp_data/<hash>_rgba_generated.mp4``
            is missing, or if ``vis_data/<hash>_rgba_static.mp4`` is missing
            (checked in that order).
    """
    if image_block is None:
        raise gr.Error("Please select or upload an input image")
    key = hashlib.sha256(image_block.tobytes()).hexdigest()
    required = (
        (os.path.join('tmp_data', f'{key}_rgba_generated.mp4'), "Please generate a video first"),
        (os.path.join('vis_data', f'{key}_rgba_static.mp4'), "Please generate a 3D first"),
    )
    for artifact, message in required:
        if not os.path.exists(artifact):
            raise gr.Error(message)
@spaces.GPU()
def optimize_stage_0(image_block: Image.Image, preprocess_chk: bool, seed_slider: int):
    """Video stage: turn the input image into a driving video.

    Files for an image are keyed by the SHA-256 of its raw pixel bytes and live
    under ``tmp_data/``.

    Args:
        image_block: Input image from the UI.
        preprocess_chk: If True, remove the background and recentre the object
            with ``preprocess``; otherwise the image is saved as-is as the RGBA
            input.
        seed_slider: Seed for video generation.

    Returns:
        Path of the generated video, ``tmp_data/<hash>_rgba_generated.mp4``.
    """
    # exist_ok avoids a check-then-create race: the UI lets this stage and the
    # 3D stage run in parallel, and both create this directory.
    os.makedirs('tmp_data', exist_ok=True)
    img_hash = hashlib.sha256(image_block.tobytes()).hexdigest()
    rgba_path = os.path.join('tmp_data', f'{img_hash}_rgba.png')
    # NOTE(review): an existing _rgba.png is reused regardless of preprocess_chk,
    # so toggling the checkbox for an already-processed image has no effect.
    if not os.path.exists(rgba_path):
        if preprocess_chk:
            raw_path = os.path.join('tmp_data', f'{img_hash}.png')
            image_block.save(raw_path)
            # preprocess writes <stem>_rgba.png next to its input, i.e. rgba_path.
            preprocess(raw_path)
        else:
            image_block.save(rgba_path)
    gen_vid(rgba_path, seed_slider)
    return os.path.join('tmp_data', f'{img_hash}_rgba_generated.mp4')
@spaces.GPU()
def optimize_stage_1(image_block: Image.Image, preprocess_chk: bool, seed_slider: int):
    """3D stage: reconstruct a static 3D model from the input image with LGM.

    Files for an image are keyed by the SHA-256 of its raw pixel bytes; the RGBA
    input lives under ``tmp_data/``.

    Args:
        image_block: Input image from the UI.
        preprocess_chk: If True, remove the background and recentre the object
            with ``preprocess``; otherwise the image is saved as-is as the RGBA
            input.
        seed_slider: Seed passed to ``process_lgm``.

    Returns:
        Path ``vis_data/<hash>_rgba_static.mp4`` of the static 3D render. This
        function only builds the path; the file is presumably written by
        ``process_lgm`` (TODO confirm).
    """
    # exist_ok avoids a check-then-create race: the UI lets this stage and the
    # video stage run in parallel, and both create this directory.
    os.makedirs('tmp_data', exist_ok=True)
    img_hash = hashlib.sha256(image_block.tobytes()).hexdigest()
    rgba_path = os.path.join('tmp_data', f'{img_hash}_rgba.png')
    # NOTE(review): an existing _rgba.png is reused regardless of preprocess_chk,
    # so toggling the checkbox for an already-processed image has no effect.
    if not os.path.exists(rgba_path):
        if preprocess_chk:
            raw_path = os.path.join('tmp_data', f'{img_hash}.png')
            image_block.save(raw_path)
            # preprocess writes <stem>_rgba.png next to its input, i.e. rgba_path.
            preprocess(raw_path)
        else:
            image_block.save(rgba_path)
    process_lgm(opt, rgba_path, pipe_mvdream, model, rays_embeddings, seed_slider)
    return os.path.join('vis_data', f'{img_hash}_rgba_static.mp4')
@spaces.GPU(duration=120)
def optimize_stage_2(image_block: Image.Image, seed_slider: int):
    """4D stage: optimise the dynamic Gaussians and return the per-frame .ply files.

    Runs ``process_dg4d`` with ``configs/4d_demo.yaml`` on the RGBA image that
    the earlier stages saved as ``tmp_data/<hash>_rgba.png``.

    Args:
        image_block: Input image; used only to recompute its content hash.
        seed_slider: Accepted to match the UI wiring; not used by this stage.

    Returns:
        The 28 paths ``logs/<hash>_rgba_frames/000.ply`` ... ``027.ply``.
    """
    digest = hashlib.sha256(image_block.tobytes()).hexdigest()
    config_path = os.path.join("configs", "4d_demo.yaml")
    rgba_input = os.path.join("tmp_data", f"{digest}_rgba.png")
    process_dg4d(config_path, rgba_input, guidance_zero123)
    frames_dir = os.path.join('logs', f'{digest}_rgba_frames')
    return [f'{frames_dir}/{frame:03d}.ply' for frame in range(28)]
if __name__ == "__main__":
    _TITLE = '''DreamGaussian4D: Generative 4D Gaussian Splatting'''
    # HTML/markdown shown under the title (project badges + one-line summary).
    _DESCRIPTION = '''
<div>
<a style="display:inline-block" href="https://jiawei-ren.github.io/projects/dreamgaussian4d/"><img src='https://img.shields.io/badge/public_website-8A2BE2'></a>
<a style="display:inline-block; margin-left: .5em" href="https://arxiv.org/abs/2312.17142"><img src="https://img.shields.io/badge/2312.17142-f9f7f7?logo="></a>
<a style="display:inline-block; margin-left: .5em" href='https://github.com/jiawei-ren/dreamgaussian4d'><img src='https://img.shields.io/github/stars/jiawei-ren/dreamgaussian4d?style=social'/></a>
</div>
We present DreamGaussian4D, an efficient 4D generation framework that builds on Gaussian Splatting.
'''
    # The example gallery is declared after this guide text, so it renders below it.
    _IMG_USER_GUIDE = "Please upload an image in the block above (or choose an example below), click **Generate Video** and **Generate 3D** (they can run in parallel). Finally, click **Generate 4D**."
    example_folder = os.path.join(os.path.dirname(__file__), 'data')
    # Each row is [input image, video seed]; it must match gr.Examples(inputs=...).
    examples_full = [
        [example_folder + '/panda.png', 40284],
        [example_folder + '/csm_luigi_rgba.png', 10],
        [example_folder + '/anya_rgba.png', 42],
    ]

    # Compose demo layout & data flow
    with gr.Blocks(title=_TITLE, theme=gr.themes.Soft(), js=js_func) as demo:
        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown('# ' + _TITLE)
        gr.Markdown(_DESCRIPTION)

        # Image-to-4D
        with gr.Row(variant='panel'):
            # Left column: input image, controls and examples.
            with gr.Column(scale=5):
                image_block = gr.Image(type='pil', image_mode='RGBA', height=290, label='Input image')
                seed_slider = gr.Slider(0, 100000, value=0, step=1, label='Random Seed (Video)')
                seed_slider2 = gr.Slider(0, 100000, value=0, step=1, label='Random Seed (3D)')
                gr.Markdown(
                    "random seed for video generation.")
                preprocess_chk = gr.Checkbox(True,
                                             label='Preprocess image automatically (remove background and recenter object)')
                with gr.Row():
                    with gr.Column(scale=5):
                        img_run_btn = gr.Button("Generate Video")
                    with gr.Column(scale=5):
                        threed_run_btn = gr.Button("Generate 3D")
                fourd_run_btn = gr.Button("Generate 4D")
                img_guide_text = gr.Markdown(_IMG_USER_GUIDE, visible=True)
                gr.Examples(
                    examples=examples_full,  # NOTE: elements must match inputs list!
                    inputs=[image_block, seed_slider],
                    outputs=[image_block],
                    cache_examples=False,
                    label='Examples (click one of the examples below to start)',
                    examples_per_page=40
                )
            # Right column: outputs of the video, 3D and 4D stages.
            with gr.Column(scale=5):
                with gr.Row():
                    with gr.Column(scale=5):
                        driving_video = gr.Video(label="video", height=290)
                    with gr.Column(scale=5):
                        obj3d = gr.Video(label="3D Model", height=290)
                obj4d = Model4DGS(label="4D Model", height=500, fps=28)
                gr.Markdown("*Please refresh the page before a new run.*")

            # Each button first runs a cheap, unqueued validation and only
            # launches the GPU stage if that validation succeeds.
            img_run_btn.click(check_img_input, inputs=[image_block], queue=False).success(
                optimize_stage_0,
                inputs=[image_block, preprocess_chk, seed_slider],
                outputs=[driving_video],
            )
            threed_run_btn.click(check_img_input, inputs=[image_block], queue=False).success(
                optimize_stage_1,
                inputs=[image_block, preprocess_chk, seed_slider2],
                outputs=[obj3d],
            )
            fourd_run_btn.click(check_video_3d_input, inputs=[image_block], queue=False).success(
                optimize_stage_2,
                inputs=[image_block, seed_slider],
                outputs=[obj4d],
            )

    # Queue incoming events; at most 10 may be waiting at once.
    demo.queue(max_size=10)
    demo.launch()