ThomasSimonini HF staff commited on
Commit
881fc76
β€’
1 Parent(s): 3004373

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +155 -0
app.py ADDED
@@ -0,0 +1,155 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import numpy as np
3
+ import spaces
4
+ import torch
5
+ import rembg
6
+ from PIL import Image
7
+ from functools import partial
8
+
9
+ import logging
10
+ import os
11
+ import shlex
12
+ import subprocess
13
+ import tempfile
14
+ import time
15
+
16
+ subprocess.run(shlex.split('pip install wheel/torchmcubes-0.1.0-cp310-cp310-linux_x86_64.whl'))
17
+
18
+ from tsr.system import TSR
19
+ from tsr.utils import remove_background, resize_foreground, to_gradio_3d_orientation
20
+
21
+
22
+ STEP1_HEADER = """
23
+ # Step 1: Generate the 3D Mesh
24
+
25
+ For this step, we use TripoSR, an open-source model for **fast** feedforward 3D reconstruction from a single image, developed in collaboration between [Tripo AI](https://www.tripo3d.ai/) and [Stability AI](https://stability.ai/).
26
+
27
+ During this step, you need to upload an image of what you want to generate a 3D Model from.
28
+
29
+
30
+ ## πŸ’‘ Tips
31
+
32
+ - If there's a background, βœ… Remove background.
33
+
34
+ - If you find the result is unsatisfied, please try to change the foreground ratio. It might improve the results.
35
+
36
+
37
+ """
38
+
39
+ # These part of the code (check_input_image and preprocess were taken from https://huggingface.co/spaces/stabilityai/TripoSR/blob/main/app.py)
40
+ if torch.cuda.is_available():
41
+ device = "cuda:0"
42
+ else:
43
+ device = "cpu"
44
+
45
+ model = TSR.from_pretrained(
46
+ "stabilityai/TripoSR",
47
+ config_name="config.yaml",
48
+ weight_name="model.ckpt",
49
+ )
50
+ model.renderer.set_chunk_size(131072)
51
+ model.to(device)
52
+
53
+ rembg_session = rembg.new_session()
54
+
55
+
56
+ def check_input_image(input_image):
57
+ if input_image is None:
58
+ raise gr.Error("No image uploaded!")
59
+
60
+
61
+ def preprocess(input_image, do_remove_background, foreground_ratio):
62
+ def fill_background(image):
63
+ image = np.array(image).astype(np.float32) / 255.0
64
+ image = image[:, :, :3] * image[:, :, 3:4] + (1 - image[:, :, 3:4]) * 0.5
65
+ image = Image.fromarray((image * 255.0).astype(np.uint8))
66
+ return image
67
+
68
+ if do_remove_background:
69
+ image = input_image.convert("RGB")
70
+ image = remove_background(image, rembg_session)
71
+ image = resize_foreground(image, foreground_ratio)
72
+ image = fill_background(image)
73
+ else:
74
+ image = input_image
75
+ if image.mode == "RGBA":
76
+ image = fill_background(image)
77
+ return image
78
+
79
+
80
+ @spaces.GPU
81
+ def generate(image, mc_resolution, formats=["obj", "glb"]):
82
+ scene_codes = model(image, device=device)
83
+ mesh = model.extract_mesh(scene_codes, resolution=mc_resolution)[0]
84
+ mesh = to_gradio_3d_orientation(mesh)
85
+
86
+ mesh_path_glb = tempfile.NamedTemporaryFile(suffix=f".glb", delete=False)
87
+ mesh.export(mesh_path_glb.name)
88
+
89
+ mesh_path_obj = tempfile.NamedTemporaryFile(suffix=f".obj", delete=False)
90
+ mesh.apply_scale([-1, 1, 1]) # Otherwise the visualized .obj will be flipped
91
+ mesh.export(mesh_path_obj.name)
92
+
93
+ return mesh_path_obj.name, mesh_path_glb.name
94
+
95
+
96
+ with gr.Blocks() as demo:
97
+ gr.Markdown(STEP1_HEADER)
98
+ with gr.Row(variant = "panel"):
99
+ with gr.Column():
100
+ with gr.Row():
101
+ input_image = gr.Image(
102
+ label = "Input Image",
103
+ image_mode = "RGBA",
104
+ sources = "upload",
105
+ type="pil",
106
+ elem_id="content_image")
107
+ processed_image = gr.Image(label="Processed Image", interactive=False)
108
+ with gr.Row():
109
+ with gr.Group():
110
+ do_remove_background = gr.Checkbox(
111
+ label="Remove Background",
112
+ value=True)
113
+ foreground_ratio = gr.Slider(
114
+ label="Foreground Ratio",
115
+ minimum=0.5,
116
+ maximum=1.0,
117
+ value=0.85,
118
+ step=0.05,
119
+ )
120
+ mc_resolution = gr.Slider(
121
+ label="Marching Cubes Resolution",
122
+ minimum=32,
123
+ maximum=320,
124
+ value=256,
125
+ step=32
126
+ )
127
+ with gr.Row():
128
+ step1_submit = gr.Button("Generate", elem_id="generate", variant="primary")
129
+
130
+ with gr.Column():
131
+ with gr.Tab("OBJ"):
132
+ output_model_obj = gr.Model3D(
133
+ label = "Output Model (OBJ Format)",
134
+ interative = False,
135
+ )
136
+ gr.Markdown("Note: Downloaded object will be flipped in case of .obj export. Export .glb instead or manually flip it before usage.")
137
+ with gr.Tab("GLB"):
138
+ output_model_glb = gr.Model3D(
139
+ label="Output Model (GLB Format)",
140
+ interactive=False,
141
+ )
142
+ gr.Markdown("Note: The model shown here has a darker appearance. Download to get correct results.")
143
+
144
+ step1_submit.click(fn=check_input_image, inputs=[input_image]).success(
145
+ fn=preprocess,
146
+ inputs=[input_image, do_remove_background, foreground_ratio],
147
+ outputs=[processed_image],
148
+ ).success(
149
+ fn=generate,
150
+ inputs=[processed_image, mc_resolution],
151
+ outputs=[output_model_obj, output_model_glb],
152
+ )
153
+
154
+ demo.queue(max_size=10)
155
+ demo.launch()