File size: 15,494 Bytes
1a14066
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
import os
import wget
import subprocess
import sys
import torch

if os.getenv('SYSTEM') == 'spaces':
    # On Hugging Face Spaces, install the prebuilt pytorch3d wheel whose tag matches the
    # running interpreter (py3<minor>), CUDA toolkit (cu<ver>) and torch release (pyt<ver>).
    pyt_version_str = torch.__version__.split("+")[0].replace(".", "")
    cuda_version_str = torch.version.cuda.replace(".", "")
    version_str = f"py3{sys.version_info.minor}_cu{cuda_version_str}_pyt{pyt_version_str}"
    wheel_index_url = (
        "https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/"
        f"{version_str}/download.html"
    )
    subprocess.run([
        "pip", "install", "--no-index", "--no-cache-dir", "pytorch3d",
        "-f", wheel_index_url,
    ])


import argparse
import gradio as gr
from functools import partial
from my.config import BaseConf, dispatch_gradio
from run_3DFuse import SJC_3DFuse
import numpy as np
from PIL import Image
from pc_project import point_e
from diffusers import UnCLIPPipeline, DiffusionPipeline
from pc_project import point_e_gradio
import numpy as np
import plotly.graph_objs as go
from my.utils.seed import seed_everything

# Markdown/HTML banner shown at the top of the demo: warns that the shared Space is slow
# and links to "Duplicate this Space" and to the project page. Plain (non-f) string: it
# contains no placeholders.
SHARED_UI_WARNING = '''### [NOTE]  Training may be very slow in this shared UI.
You can duplicate and use it with a paid private GPU.
<a class="duplicate-button" style="display:inline-block" target="_blank" href="https://huggingface.co/spaces/jyseo/3DFuse?duplicate=true"><img style="margin-top:0;margin-bottom:0" src="https://huggingface.co/datasets/huggingface/badges/raw/main/duplicate-this-space-xl-dark.svg" alt="Duplicate Space"></a>
Alternatively, you can also use the Colab demo on our project page.
<a style="display:inline-block" href="https://ku-cvlab.github.io/3DFuse/"><img style="margin-top:0;margin-bottom:0" src="https://img.shields.io/badge/Project%20Page-online-brightgreen"></a>
'''

class Intermediate:
    """Mutable holder for results handed from the point-cloud step to the 3D step.

    Attributes (all start out empty):
        images: initial image(s) stored by the point-cloud handlers; ``None`` until then.
        points: point cloud returned by ``point_e_gradio``; ``None`` until then.
        is_generating: flag toggled by ``gen_3d`` around 3D generation.
    """

    def __init__(self):
        # Nothing has been generated yet.
        self.images, self.points = None, None
        self.is_generating = False


def gen_3d(model, intermediate, prompt, keyword, seed, ti_step, pt_step):
    """Run 3DFuse on the images and point cloud cached in ``intermediate``.

    This is a generator used as a Gradio event handler: it re-yields every update
    produced by ``model.run_gradio`` so the UI can stream intermediate results.

    Args:
        model: Unused placeholder kept for interface compatibility (the UI binds
            ``None`` here via ``functools.partial``). A fresh model is built on every call.
        intermediate: ``Intermediate`` whose ``images``/``points`` were filled by a
            point-cloud handler; ``is_generating`` is set while this generator runs.
        prompt: Text prompt for 3D generation.
        keyword: Keyword used for initialization / textual inversion.
        seed: Random seed.
        ti_step: Number of text-embedding optimization steps.
        pt_step: Number of pivotal-tuning steps for LoRA.

    Raises:
        gr.Error: If no point cloud has been generated yet.
    """
    images, points = intermediate.images, intermediate.points
    # Validate before flipping the flag so a rejected request never leaves it set.
    if images is None or points is None:
        raise gr.Error("Please generate point cloud first")

    intermediate.is_generating = True
    try:
        seed_everything(seed)
        model = dispatch_gradio(SJC_3DFuse, prompt, keyword, ti_step, pt_step, seed)
        yield from model.run_gradio(points, images)
    finally:
        # Reset even if generation raises or the client abandons the stream
        # (closing the generator raises GeneratorExit at the pending yield).
        intermediate.is_generating = False
    


def _validate_keyword(prompt, keyword):
    """Raise ``gr.Error`` unless ``keyword`` is a single word contained in ``prompt``.

    The keyword drives textual inversion, so it must be non-empty and contain no
    whitespace (spaces, tabs or newlines).
    """
    if keyword not in prompt:
        raise gr.Error("Prompt should contain keyword!")
    if not keyword or any(ch.isspace() for ch in keyword):
        raise gr.Error("Keyword should be one word!")


def _make_bg_remover():
    """Build the carvekit ``HiInterface`` used to strip image backgrounds.

    carvekit is imported lazily so it is only required when background removal is
    actually requested.
    """
    from carvekit.api.high import HiInterface
    return HiInterface(object_type="object",
                       batch_size_seg=5,
                       batch_size_matting=1,
                       device='cuda' if torch.cuda.is_available() else 'cpu',
                       seg_mask_size=640,  # Use 640 for Tracer B7 and 320 for U2Net
                       matting_mask_size=2048,
                       trimap_prob_threshold=231,
                       trimap_dilation=30,
                       trimap_erosion_iters=5,
                       fp16=False)


def _foreground_mask(matted_image, threshold=127):
    """Return an ``(H, W)`` boolean mask of pixels whose matte value exceeds ``threshold``.

    If ``matted_image`` has a channel axis, its last channel is used as the matte.
    Thresholding every channel instead yields an ``(H, W, C)`` mask, which cannot
    boolean-index an ``(H, W, 3)`` RGB array.
    """
    matte = np.asarray(matted_image)
    if matte.ndim == 3:
        # NOTE(review): assumes the background remover returns RGBA with the matte in
        # alpha (as other carvekit users index ``[..., -1]``) — confirm for the pinned version.
        matte = matte[..., -1]
    return matte > threshold


def _remove_background(interface, image):
    """Return a copy of PIL ``image`` with background pixels painted white.

    Motivated by NeuralLift-360 (removing the background).
    """
    mask = _foreground_mask(interface([image])[0])
    # Normalize to 3 channels so the white fill fits RGBA / grayscale inputs too.
    pixels = np.array(image.convert("RGB"))
    pixels[~mask] = 255
    return Image.fromarray(pixels)


def _point_cloud_figure(points):
    """Build a borderless Plotly 3D scatter of ``points.coords`` with axis decorations hidden."""
    coords = np.array(points.coords)
    trace = go.Scatter3d(x=coords[:, 0], y=coords[:, 1], z=coords[:, 2],
                         mode='markers', marker=dict(size=2))
    hidden_axis = dict(
        title="",
        showgrid=False,
        zeroline=False,
        showline=False,
        ticks='',
        showticklabels=False,
    )
    layout = go.Layout(
        scene=dict(xaxis=dict(hidden_axis), yaxis=dict(hidden_axis), zaxis=dict(hidden_axis)),
        margin=dict(l=0, r=0, b=0, t=0),
        showlegend=False,
    )
    return go.Figure(data=[trace], layout=layout)


def gen_pc_from_prompt(intermediate, num_initial_image, prompt, keyword, type, bg_preprocess, seed):
    """Generate initial images for ``prompt`` and infer a point cloud from the first one.

    Stores the images and point cloud on ``intermediate`` for a later ``gen_3d`` call.

    Args:
        intermediate: ``Intermediate`` to fill with ``images`` and ``points``.
        num_initial_image: Number of initial views to generate (must be >= 1).
        prompt: Text prompt; must contain ``keyword``.
        keyword: Single word used for initialization / textual inversion.
        type: Backbone choice from the UI radio (see ``gen_init``).
        bg_preprocess: Whether to whiten the background of generated images.
        seed: Random seed.

    Returns:
        ``(first image, Plotly figure of the point cloud, gr.update enabling the 3D button)``.

    Raises:
        gr.Error: If the keyword is not a single word contained in the prompt.
    """
    # Gradio sliders may deliver numbers as floats; RNG seeding APIs expect ints.
    seed = int(seed)
    seed_everything(seed=seed)
    _validate_keyword(prompt, keyword)

    images = gen_init(num_initial_image=num_initial_image, prompt=prompt, seed=seed,
                      type=type, bg_preprocess=bg_preprocess)
    points = point_e_gradio(images[0], 'cuda')

    intermediate.images = images
    intermediate.points = points

    return images[0], _point_cloud_figure(points), gr.update(interactive=True)


def gen_pc_from_image(intermediate, image, prompt, keyword, bg_preprocess, seed):
    """Infer a point cloud from an uploaded image, optionally after whitening its background.

    Stores ``[image]`` and the point cloud on ``intermediate`` for a later ``gen_3d`` call.

    Returns:
        ``(processed image, Plotly figure of the point cloud, gr.update enabling the 3D button)``.

    Raises:
        gr.Error: If the keyword is invalid or no image was uploaded.
    """
    # Gradio sliders may deliver numbers as floats; RNG seeding APIs expect ints.
    seed = int(seed)
    seed_everything(seed=seed)
    _validate_keyword(prompt, keyword)
    if image is None:
        raise gr.Error("Please upload an image first!")

    if bg_preprocess:
        image = _remove_background(_make_bg_remover(), image)

    points = point_e_gradio(image, 'cuda')

    intermediate.images = [image]
    intermediate.points = points

    return image, _point_cloud_figure(points), gr.update(interactive=True)


def gen_init(num_initial_image, prompt, seed, type="Karlo", bg_preprocess=False):
    """Generate ``num_initial_image`` initial images of ``prompt`` from rotating viewpoints.

    Args:
        num_initial_image: Number of images; view prefixes cycle front/overhead/side/back.
        prompt: Text prompt.
        seed: Seed for the CUDA ``torch.Generator`` shared by all pipeline calls.
        type: Backbone name. Values starting with ``"Karlo"`` (e.g. the default or
            ``"Karlo (Recommended)"``) select Karlo (UnCLIP); anything else selects
            Stable Diffusion v1.5.
        bg_preprocess: Whether to whiten each image's background.

    Returns:
        List with the first output image of each pipeline call.
    """
    seed = int(seed)
    if (type or "").startswith("Karlo"):
        pipe = UnCLIPPipeline.from_pretrained("kakaobrain/karlo-v1-alpha", torch_dtype=torch.float16)
    else:
        pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
    pipe = pipe.to('cuda')

    view_prompts = ["front view of ", "overhead view of ", "side view of ", "back view of "]

    # NOTE: This option was added during the code organization process.
    # The results reported in the paper were obtained with [bg_preprocess: False] setting.
    interface = _make_bg_remover() if bg_preprocess else None

    generator = torch.Generator(device='cuda').manual_seed(seed)
    images = []
    for i in range(num_initial_image):
        view = view_prompts[i % len(view_prompts)]
        # Only the first (front) view asks for a white background, regardless of bg_preprocess.
        suffix = ", white background" if i == 0 else ""
        image = pipe(f"{view}{prompt}{suffix}", generator=generator).images[0]

        if interface is not None:
            image = _remove_background(interface, image)
        images.append(image)

    return images
            
        

if __name__ == '__main__':

    parser = argparse.ArgumentParser()
    parser.add_argument('--share', action='store_true', help="public url")
    args = parser.parse_args()

    # Make sure the sparse depth injector checkpoint is available locally.
    weights_dir = './weights'
    os.makedirs(weights_dir, exist_ok=True)
    weights_path = os.path.join(weights_dir, '3DFuse_sparse_depth_injector.ckpt')

    # If the file does not exist, download it with wget and save it.
    if not os.path.isfile(weights_path):
        url = 'https://huggingface.co/jyseo/3DFuse_weights/resolve/main/models/3DFuse_sparse_depth_injector.ckpt'
        wget.download(url, weights_path)
        print(f'{weights_path} downloaded.')
    else:
        print(f'{weights_path} already exists.')

    # gen_3d builds its own model per call; this only fills its first positional argument.
    model = None
    # NOTE(review): this single Intermediate is bound into every handler via
    # functools.partial, so all browser sessions (and both tabs) share the same cached
    # images / point cloud. Consider per-session gr.State if concurrent users matter.
    intermediate = Intermediate()
    demo = gr.Blocks(title="3DFuse Interactive Demo")

    with demo:
        with gr.Box():
            gr.Markdown(SHARED_UI_WARNING)

        gr.Markdown("# 3DFuse Interactive Demo")
        gr.Markdown("### Official Implementation of the paper \"Let 2D Diffusion Model Know 3D-Consistency for Robust Text-to-3D Generation\"")
        gr.Markdown("Enter your own prompt and enjoy! With this demo, you can preview the point cloud before 3D generation and determine the desired shape.")
        with gr.Row():
            with gr.Column(scale=1., variant='panel'):

                with gr.Tab("Text to 3D"):
                    prompt_input = gr.Textbox(label="Prompt", value="a comfortable bed", interactive=True)
                    word_input = gr.Textbox(label="Keyword for initialization (should be a noun included in the prompt)", value="bed", interactive=True)
                    semantic_model_choice = gr.Radio(["Karlo (Recommended)","Stable Diffusion"], label="Backbone for initial image generation", value="Karlo (Recommended)", interactive=True)
                    seed = gr.Slider(label="Seed", minimum=0, maximum=2100000000, step=1, randomize=True)
                    preprocess_choice = gr.Checkbox(True, label="Preprocess initially-generated image by removing background", interactive=True)
                    with gr.Accordion("Advanced Options", open=False):
                        opt_step = gr.Slider(0, 1000, value=500, step=1, label='Number of text embedding optimization step')
                        pivot_step = gr.Slider(0, 1000, value=500, step=1, label='Number of pivotal tuning step for LoRA')
                    with gr.Row():
                        button_gen_pc = gr.Button("1. Generate Point Cloud", interactive=True, variant='secondary')
                        button_gen_3d = gr.Button("2. Generate 3D", interactive=False, variant='primary')

                with gr.Tab("Image to 3D"):
                    image_input = gr.Image(source='upload', type="pil", interactive=True)
                    prompt_input_2 = gr.Textbox(label="Prompt", value="a dog", interactive=True)
                    word_input_2 = gr.Textbox(label="Keyword for initialization (should be a noun included in the prompt)", value="dog", interactive=True)
                    seed_2 = gr.Slider(label="Seed", minimum=0, maximum=2100000000, step=1, randomize=True)
                    preprocess_choice_2 = gr.Checkbox(True, label="Preprocess initially-generated image by removing background", interactive=False)
                    with gr.Accordion("Advanced Options", open=False):
                        opt_step_2 = gr.Slider(0, 1000, value=500, step=1, label='Number of text embedding optimization step')
                        pivot_step_2 = gr.Slider(0, 1000, value=500, step=1, label='Number of pivotal tuning step for LoRA')
                    with gr.Row():
                        button_gen_pc_2 = gr.Button("1. Generate Point Cloud", interactive=True, variant='secondary')
                        button_gen_3d_2 = gr.Button("2. Generate 3D", interactive=False, variant='primary')
                    gr.Markdown("Note: A photo showing the entire object in a front view is recommended. Also, our framework is not designed with the goal of single shot reconstruction, so it may be difficult to reflect the details of the given image.")

                with gr.Row(scale=1.):
                    with gr.Column(scale=1.):
                        pc_plot = gr.Plot(label="Inferred point cloud")
                    with gr.Column(scale=1.):
                        init_output = gr.Image(label='Generated initial image', interactive=False)

            with gr.Column(scale=1., variant='panel'):
                with gr.Row():
                    with gr.Column(scale=1.):
                        intermediate_output = gr.Image(label="Intermediate Output", interactive=False)
                    with gr.Column(scale=1.):
                        logs = gr.Textbox(label="logs", lines=8, max_lines=8, interactive=False)
                with gr.Row():
                    video_result = gr.Video(label="Video result for generated 3D", interactive=False)

        gr.Markdown("Note: Keyword is used for Textual Inversion. Please choose one important noun included in the prompt. This demo may be slower than directly running run_3DFuse.py.")

        # Event wiring: step 1 fills `intermediate` and enables the matching step-2 button;
        # step 2 streams intermediate image, logs and the final video.
        button_gen_pc.click(fn=partial(gen_pc_from_prompt, intermediate, 4), inputs=[prompt_input, word_input, semantic_model_choice,
            preprocess_choice, seed], outputs=[init_output, pc_plot, button_gen_3d])
        button_gen_3d.click(fn=partial(gen_3d, model, intermediate), inputs=[prompt_input, word_input, seed, opt_step, pivot_step],
            outputs=[intermediate_output, logs, video_result])

        button_gen_pc_2.click(fn=partial(gen_pc_from_image, intermediate), inputs=[image_input, prompt_input_2, word_input_2,
            preprocess_choice_2, seed_2], outputs=[init_output, pc_plot, button_gen_3d_2])
        button_gen_3d_2.click(fn=partial(gen_3d, model, intermediate), inputs=[prompt_input_2, word_input_2, seed_2, opt_step_2, pivot_step_2],
            outputs=[intermediate_output, logs, video_result])

    # Process one event at a time (GPU-bound handlers).
    demo.queue(concurrency_count=1)
    demo.launch(share=args.share)