File size: 4,014 Bytes
062b5a5
37aeb5b
 
 
 
5a3e910
 
 
37aeb5b
 
69ac8ac
 
 
 
 
 
 
 
 
37aeb5b
 
 
69ac8ac
37aeb5b
 
 
69ac8ac
 
 
 
 
 
 
 
37aeb5b
 
 
 
 
 
 
 
 
 
 
69ac8ac
37aeb5b
 
 
 
 
 
 
 
c07e086
37aeb5b
 
 
 
69ac8ac
37aeb5b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
import spaces
import os
import gradio as gr
from PIL import Image
from pytorch3d.structures import Meshes
from gradio_app.utils import clean_up
from gradio_app.custom_models.mvimg_prediction import run_mvprediction
from gradio_app.custom_models.normal_prediction import predict_normals
from scripts.refine_lr_to_sr import run_sr_fast
from scripts.utils import save_glb_and_video
# from scripts.multiview_inference import geo_reconstruct
from scripts.multiview_inference import geo_reconstruct_part1, geo_reconstruct_part2, geo_reconstruct_part3

@spaces.GPU
def run_mv(preview_img, input_processing, seed):
    """Produce multi-view predictions for a single preview image.

    Args:
        preview_img: Image object exposing ``.size``; its first size
            component is compared against 512 to decide on upscaling.
        input_processing: Forwarded to ``run_mvprediction`` as ``remove_bg``.
        seed: Converted with ``int()`` and forwarded as ``seed``.

    Returns:
        A ``(rgb_pils, front_pil)`` tuple as unpacked from
        ``run_mvprediction``.
    """
    needs_upscale = preview_img.size[0] <= 512
    if needs_upscale:
        # run_sr_fast is called with a one-element list; only the first
        # element of its result is used.
        upscaled = run_sr_fast([preview_img])
        preview_img = upscaled[0]

    # Original author's note: this call takes roughly 6 s.
    multiview_imgs, front_img = run_mvprediction(
        preview_img,
        remove_bg=input_processing,
        seed=int(seed),
    )
    return multiview_imgs, front_img

def generate3dv2(preview_img, input_processing, seed, render_video=True, do_refine=True, expansion_weight=0.1, init_type="std"):
    """Run the full image-to-3D pipeline and export the result.

    Steps: multi-view prediction (``run_mv``), the three
    ``geo_reconstruct_part*`` stages, a rescale/flip of the vertices, and
    finally ``save_glb_and_video``.

    Args:
        preview_img: An image object, or a ``str`` path that is opened with
            ``Image.open``. ``None`` is rejected.
        input_processing: Forwarded to ``run_mv``.
        seed: Forwarded to ``run_mv``.
        render_video: Forwarded to ``save_glb_and_video`` as ``export_video``.
        do_refine: Forwarded to ``geo_reconstruct_part1``.
        expansion_weight: Forwarded to ``geo_reconstruct_part1``.
        init_type: Forwarded to ``geo_reconstruct_part1``.

    Returns:
        The ``(ret_mesh, video)`` tuple unpacked from ``save_glb_and_video``.

    Raises:
        gr.Error: If ``preview_img`` is ``None``.
    """
    if preview_img is None:
        raise gr.Error("The input image is none!")
    source_img = Image.open(preview_img) if isinstance(preview_img, str) else preview_img

    mv_images, front_view = run_mv(source_img, input_processing, seed)

    stage1_options = dict(
        do_refine=do_refine,
        predict_normal=True,
        expansion_weight=expansion_weight,
        init_type=init_type,
    )
    verts, tris, img_list = geo_reconstruct_part1(mv_images, None, front_view, **stage1_options)
    stage2_meshes = geo_reconstruct_part2(verts, tris)
    stage3_meshes = geo_reconstruct_part3(stage2_meshes, img_list)

    # Rescale the packed vertices, then negate components 0 and 2 of the
    # last dimension (presumably x and z -- confirm against the exporter).
    packed_verts = stage3_meshes.verts_packed()
    packed_verts = packed_verts / 2 * 1.35
    flipped_components = [0, 2]
    packed_verts[..., flipped_components] = -packed_verts[..., flipped_components]

    export_meshes = Meshes(
        verts=[packed_verts],
        faces=stage3_meshes.faces_list(),
        textures=stage3_meshes.textures,
    )

    # NOTE(review): fov 2 / 1.35 uses the same two constants as the vertex
    # rescale above; they look coupled -- confirm before changing either.
    mesh_path, video_path = save_glb_and_video(
        "/tmp/gradio/generated",
        export_meshes,
        with_timestamp=True,
        dist=3.5,
        fov_in_degrees=2 / 1.35,
        cam_type="ortho",
        export_video=render_video,
    )
    return mesh_path, video_path

#######################################
def create_ui(concurrency_id="wkl"):
    """Build the image-to-3D Gradio layout and wire the generate button.

    Creates a two-column row: the left column holds the input image and an
    examples gallery listing the ``examples`` folder next to this file; the
    right column holds the outputs, the option widgets, and the button.
    Presumably meant to be called inside an active ``gr.Blocks`` context --
    confirm against the caller.

    Args:
        concurrency_id: Forwarded to the button's ``click`` listener.

    Returns:
        The input ``gr.Image`` component.
    """
    with gr.Row():
        with gr.Column(scale=1):
            frontview_input = gr.Image(type='pil', image_mode='RGBA', label='Frontview')

            examples_dir = os.path.join(os.path.dirname(__file__), "./examples")
            example_paths = sorted(
                os.path.join(examples_dir, entry) for entry in os.listdir(examples_dir)
            )
            gr.Examples(
                examples=example_paths,
                inputs=[frontview_input],
                cache_examples=False,
                label='Examples',
                examples_per_page=12,
            )

        with gr.Column(scale=1):
            # Outputs: exported mesh viewer plus a video preview that starts hidden.
            mesh_view = gr.Model3D(value=None, label="Mesh Model", show_label=True, height=320)
            video_view = gr.Video(label="Preview", show_label=True, show_share_button=True, height=320, visible=False)

            remove_bg_box = gr.Checkbox(value=True, label='Remove Background', visible=True)
            # The next three controls are created hidden; their default
            # values are still passed to generate3dv2 on click.
            refine_box = gr.Checkbox(value=True, label="Refine Multiview Details", visible=False)
            expansion_slider = gr.Slider(minimum=-1., maximum=1.0, value=0.1, step=0.1, label="Expansion Weight", visible=False)
            init_dropdown = gr.Dropdown(choices=["std", "thin"], label="Mesh Initialization", value="std", visible=False)
            seed_slider = gr.Slider(-1, 1000000000, -1, step=1, visible=True, label="Seed")
            video_box = gr.Checkbox(value=False, visible=False, label="generate video")
            generate_btn = gr.Button('Generate 3D', interactive=True)

    # Order must match generate3dv2's positional parameters.
    click_inputs = [
        frontview_input,
        remove_bg_box,
        seed_slider,
        video_box,
        refine_box,
        expansion_slider,
        init_dropdown,
    ]
    click_event = generate_btn.click(
        fn=generate3dv2,
        inputs=click_inputs,
        outputs=[mesh_view, video_view],
        concurrency_id=concurrency_id,
        api_name="generate3dv2",
    )
    # clean_up runs only after a successful generation.
    click_event.success(clean_up, api_name=False)
    return frontview_input