File size: 7,793 Bytes
f80de23
 
 
 
 
 
 
 
 
 
 
 
7cbba90
 
f80de23
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
import os
import torch
import sys
import gradio as gr
import random
from configs.infer_config import get_parser
from huggingface_hub import hf_hub_download
sys.path.append('./extern/dust3r')
from dust3r.inference import inference, load_model
from omegaconf import OmegaConf
from pytorch_lightning import seed_everything
from utils.diffusion_utils import instantiate_from_config,load_model_checkpoint,image_guided_synthesis
import torchvision.transforms as transforms
import copy

# Example rows for the Gradio interface. Every row lists, in order:
#   [image path, elevation, center_scale, d_phi, d_theta, d_r, steps, seed]
# The three trajectory fields are whitespace-separated number sequences that
# start at 0. All examples share 50 sampling steps and seed 123.
i2v_examples = [
    [image_path, elevation, center_scale, d_phi, d_theta, d_r, 50, 123]
    for image_path, elevation, center_scale, d_phi, d_theta, d_r in (
        ('test/images/boy.png', 0, 1.0,
         '0 40',
         '0 0',
         '0 0'),
        ('test/images/car.jpeg', 0, 1.0,
         '0 -35',
         '0 0',
         '0 -0.1'),
        ('test/images/fruit.jpg', 0, 1.0,
         '0 -3 -15 -20 -17 -5 0',
         '0 -2 -5 -10 -8 -5 0 2 5 3 0',
         '0 0'),
        ('test/images/room.png', 5, 1.0,
         '0 3 10 20 17 10 0',
         '0 -2 -8 -6 0 2 5 3 0',
         '0 -0.02 -0.09 -0.16 -0.09 0'),
        ('test/images/castle.png', 0, 1.0,
         '0 30',
         '0 -1 -5 -4 0 1 5 4 0',
         '0 -0.2'),
    )
]

# Upper bound of the "Random Seed" slider (2 ** 31).
max_seed = 1 << 31

def download_model(repo_id='Drexubery/ViewCrafter_25', filenames=('model.ckpt',), local_dir='./checkpoints/'):
    """Download checkpoint files from the Hugging Face Hub if they are missing locally.

    Each name in ``filenames`` is looked up under ``local_dir``. Only files
    that do not exist there yet are fetched from ``repo_id``, so calling this
    on every start-up does not re-download files already on disk.

    Args:
        repo_id: Hugging Face Hub repository to download from.
        filenames: Names of the files to fetch from ``repo_id``.
        local_dir: Directory the files are stored in.

    Returns:
        A list of local file paths, one per entry of ``filenames``, in order.
    """
    local_paths = []
    for filename in filenames:
        local_file = os.path.join(local_dir, filename)
        if not os.path.exists(local_file):
            # force_download re-fetches even if a copy exists in the
            # huggingface_hub cache, so a stale cached file is not reused.
            hf_hub_download(repo_id=repo_id, filename=filename, local_dir=local_dir, force_download=True)
        local_paths.append(local_file)
    return local_paths
    
download_model()


# Page CSS: caps the size of the input image and the generated video, and
# narrows an element with id "random_button".
css = """#input_img {max-width: 1024px !important} #output_vid {max-width: 1024px; max-height:576px} #random_button {max-width: 100px !important}"""
parser = get_parser() # infer_config.py
opts = parser.parse_args() # default device: 'cuda:0'

# Fail fast, before CUDA initialization and the expensive model loads below,
# if the diffusion checkpoint is missing. An explicit raise is used instead of
# `assert`, because asserts are stripped when Python runs with -O.
if not os.path.exists(opts.ckpt_path):
    raise FileNotFoundError(f"Error: checkpoint Not Found! ({opts.ckpt_path})")

opts.save_dir = './'
os.makedirs(opts.save_dir,exist_ok=True)
# Put a tiny tensor on the current CUDA device and record that device's
# string form (e.g. 'cuda:0') as the device used for everything below.
test_tensor = torch.Tensor([0]).cuda()
opts.device = str(test_tensor.device)

dust3r = load_model(opts.model_path, opts.device)
config = OmegaConf.load(opts.config)
model_config = config.pop("model", OmegaConf.create())
# NOTE(review): `use_checkpoint` presumably toggles gradient checkpointing in
# the UNet, which is not needed for inference; confirm in the model code.
model_config['params']['unet_config']['params']['use_checkpoint'] = False
model = instantiate_from_config(model_config)
model = model.to(opts.device)
# The conditioning encoder carries its own `device` attribute, which `.to()`
# does not update, so it is set explicitly.
model.cond_stage_model.device = opts.device
model.perframe_ae = opts.perframe_ae
model = load_model_checkpoint(model, opts.ckpt_path)
model.eval()
diffusion = model
# Resize(576) scales the shorter image side to 576. CenterCrop((576, 1024))
# then crops to height 576 x width 1024. Inputs narrower than 16:9 end up
# smaller than 1024 wide, and torchvision's CenterCrop zero-pads them.
transform = transforms.Compose([
    transforms.Resize(576),
    transforms.CenterCrop((576,1024)),
    ])

def infer(opts,i2v_input_image, i2v_elevation, i2v_center_scale, i2v_d_phi, i2v_d_theta, i2v_d_r, i2v_steps, i2v_seed):
    """Run single-image novel-view synthesis for one Gradio request.

    Args:
        opts: Parsed inference options. The per-request settings are written
            onto it (``elevation``, ``center_scale``, ``ddim_steps``,
            ``gradio_traj``). ``opts.device`` and ``opts.save_dir`` are read.
        i2v_input_image: Input image as an H x W x C NumPy array with values
            in 0..255. This presumably comes from ``gr.Image`` with its
            default numpy output; TODO confirm C == 3.
        i2v_elevation: Camera elevation, converted with ``float``.
        i2v_center_scale: Center scale, converted with ``float``.
        i2v_d_phi: Whitespace-separated float sequence (d_phi trajectory).
        i2v_d_theta: Whitespace-separated float sequence (d_theta trajectory).
        i2v_d_r: Whitespace-separated float sequence (d_r trajectory).
        i2v_steps: Number of DDIM sampling steps, converted with ``int``.
        i2v_seed: Seed passed to ``seed_everything``.

    Returns:
        ``(traj_path, video_path)``: ``<save_dir>/viz_traj.mp4`` and
        ``<save_dir>/diffusion0.mp4``. These are presumably written by
        ``run_dust3r`` / ``nvs_single_view``; confirm against those helpers.

    Raises:
        ValueError: if no input image was provided.
    """
    if i2v_input_image is None:
        raise ValueError("Please upload an input image before clicking Generate.")

    # NOTE(review): run_dust3r()/nvs_single_view() receive none of these
    # values as arguments, so they are published on the shared `opts`. The
    # attribute names used downstream must match; please confirm. Because
    # `opts` is shared, requests must not run concurrently.
    opts.elevation = float(i2v_elevation)
    opts.center_scale = float(i2v_center_scale)
    opts.ddim_steps = int(i2v_steps)
    opts.gradio_traj = (
        [float(v) for v in i2v_d_phi.split()],
        [float(v) for v in i2v_d_theta.split()],
        [float(v) for v in i2v_d_r.split()],
    )
    seed_everything(i2v_seed)

    torch.cuda.empty_cache()
    # H x W x C array -> 1 x C x H x W float tensor on the model device,
    # rescaled from [0, 255] to [-1, 1].
    img_tensor = torch.from_numpy(i2v_input_image).permute(2, 0, 1).unsqueeze(0).float().to(opts.device)
    img_tensor = (img_tensor / 255. - 0.5) * 2
    image_tensor_resized = transform(img_tensor)  # 1,3,h,w
    images = get_input_dict(image_tensor_resized, idx=0, dtype=torch.float32)
    # Two entries (idx 0 and idx 1) built from the same image. The deep copy
    # keeps the two entries from sharing mutable state.
    images = [images, copy.deepcopy(images)]
    images[1]['idx'] = 1
    # NOTE(review): the previous version also computed an H x W x C [0, 1]
    # copy of the resized image here but never used it. Re-add it if
    # nvs_single_view needs the original image.

    run_dust3r(input_images=images)
    nvs_single_view(gradio=True)

    traj_dir = os.path.join(opts.save_dir, "viz_traj.mp4")
    gen_dir = os.path.join(opts.save_dir, "diffusion0.mp4")
    return traj_dir, gen_dir

def _run_infer(i2v_input_image, i2v_elevation, i2v_center_scale, i2v_d_phi, i2v_d_theta, i2v_d_r, i2v_steps, i2v_seed):
    """Gradio callback: forward the UI values to `infer` with the module-level `opts`.

    `opts` is an argparse namespace, not a Gradio component, so it cannot be
    listed in `inputs=`. Binding it here keeps the component list aligned with
    the 8-column rows of `i2v_examples`.
    """
    return infer(opts, i2v_input_image, i2v_elevation, i2v_center_scale, i2v_d_phi, i2v_d_theta, i2v_d_r, i2v_steps, i2v_seed)


with gr.Blocks(analytics_enabled=False, css=css) as viewcrafter_iface:
    gr.Markdown("<div align='center'> <h1> ViewCrafter: Taming Video Diffusion Models for High-fidelity Novel View Synthesis </h1> \
                    <h2 style='font-weight: 450; font-size: 1rem; margin: 0rem'>\
                    <a href='https://scholar.google.com/citations?user=UOE8-qsAAAAJ&hl=zh-CN'>Wangbo Yu</a>, \
                    <a href='https://doubiiu.github.io/'>Jinbo Xing</a>, <a href=''>Li Yuan</a>, \
                    <a href='https://wbhu.github.io/'>Wenbo Hu</a>, <a href='https://xiaoyu258.github.io/'>Xiaoyu Li</a>,\
                    <a href=''>Zhipeng Huang</a>, <a href='https://scholar.google.com/citations?user=qgdesEcAAAAJ&hl=en/'>Xiangjun Gao</a>,\
                    <a href='https://www.cse.cuhk.edu.hk/~ttwong/myself.html/'>Tien-Tsin Wong</a>,\
                    <a href='https://scholar.google.com/citations?hl=en&user=4oXBp9UAAAAJ&view_op=list_works&sortby=pubdate/'>Ying Shan</a>,\
                    <a href=''>Yonghong Tian</a>\
                </h2> \
                    <a style='font-size:18px;color: #FF5DB0' href='https://github.com/Drexubery/ViewCrafter/blob/main/docs/render_help.md'> [Guideline] </a>\
                    <a style='font-size:18px;color: #000000' href=''> [ArXiv] </a>\
                    <a style='font-size:18px;color: #000000' href='https://drexubery.github.io/ViewCrafter/'> [Project Page] </a>\
                    <a style='font-size:18px;color: #000000' href='https://github.com/Drexubery/ViewCrafter'> [Github] </a> </div>")

    #######image2video######
    with gr.Tab(label="ViewCrafter_25, 'single_view_txt' mode"):
        with gr.Column():
            with gr.Row():
                with gr.Column():
                    with gr.Row():
                        i2v_input_image = gr.Image(label="Input Image",elem_id="input_img")
                    with gr.Row():
                        i2v_elevation = gr.Slider(minimum=-45, maximum=45, step=1, elem_id="elevation", label="elevation", value=5)
                    with gr.Row():
                        i2v_center_scale = gr.Slider(minimum=0.1, maximum=2, step=0.1, elem_id="i2v_center_scale", label="center_scale", value=1)
                    with gr.Row():
                        i2v_d_phi = gr.Text(label='d_phi sequence, should start with 0')
                    with gr.Row():
                        i2v_d_theta = gr.Text(label='d_theta sequence, should start with 0')
                    with gr.Row():
                        i2v_d_r = gr.Text(label='d_r sequence, should start with 0')
                    with gr.Row():
                        i2v_steps = gr.Slider(minimum=1, maximum=50, step=1, elem_id="i2v_steps", label="Sampling steps", value=50)
                    with gr.Row():
                        i2v_seed = gr.Slider(label='Random Seed', minimum=0, maximum=max_seed, step=1, value=123)
                    i2v_end_btn = gr.Button("Generate")
                with gr.Column():
                    with gr.Row():
                        i2v_traj_video = gr.Video(label="Camera Trajectory",elem_id="traj_vid",autoplay=True,show_share_button=True)
                    with gr.Row():
                        i2v_output_video = gr.Video(label="Generated Video",elem_id="output_vid",autoplay=True,show_share_button=True)

            # Same order as the columns of each `i2v_examples` row and as the
            # parameters of `_run_infer`.
            i2v_inputs = [i2v_input_image, i2v_elevation, i2v_center_scale, i2v_d_phi, i2v_d_theta, i2v_d_r, i2v_steps, i2v_seed]
            i2v_outputs = [i2v_traj_video, i2v_output_video]

            gr.Examples(examples=i2v_examples,
                        inputs=i2v_inputs,
                        outputs=i2v_outputs,
                        fn=_run_infer,
                        cache_examples=False,
            )

        i2v_end_btn.click(inputs=i2v_inputs,
                        outputs=i2v_outputs,
                        fn=_run_infer,
        )

viewcrafter_iface.queue(max_size=12).launch(show_api=True)