Commit 3044d56 ("Update app.py"; parent: 1e78183)
Space status: Runtime error

app.py CHANGED
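Summary: the commit wires a second stage, MeshAnythingv2, into the existing InstantMesh Space. It imports the new dependencies, loads the model in fp16 via Accelerate, adds a matplotlib-based wireframe_render helper and a GPU-decorated do_inference entry point, and extends the Gradio UI with a Step 2 panel that runs do_inference on an uploaded mesh.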
@@ -1,6 +1,8 @@
 import spaces
 
 import os
+import time
+
 import imageio
 import numpy as np
 import torch
@@ -31,6 +33,13 @@ from huggingface_hub import hf_hub_download
 
 import gradio as gr
 
+# Imports for MeshAnythingv2
+from accelerate.utils import set_seed
+from accelerate import Accelerator
+from main import load_v2
+from mesh_to_pc import process_mesh_to_pc
+import matplotlib.pyplot as plt
+from mpl_toolkits.mplot3d.art3d import Poly3DCollection
 
 ###############################################################################
 # Configuration for InstantMesh
@@ -212,6 +221,175 @@ def make3d(images):
     return mesh_fpath, mesh_glb_fpath
 
 
+###############################################################################
+# Configuration for MeshAnythingv2
+###############################################################################
+model = load_v2()
+device = torch.device('cuda')
+accelerator = Accelerator(
+    mixed_precision="fp16",
+)
+model = accelerator.prepare(model)
+model.eval()
+print("Model loaded to device")
+
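+# prepare() moves the model onto the Accelerator's device; do_inference below
+# runs it under the fp16 autocast context created here.
+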
+def wireframe_render(mesh):
+    views = [
+        (90, 20), (270, 20)
+    ]
+    mesh.vertices = mesh.vertices[:, [0, 2, 1]]
+
+    bounding_box = mesh.bounds
+    center = mesh.centroid
+    scale = np.ptp(bounding_box, axis=0).max()
+
+    fig = plt.figure(figsize=(10, 10))
+
+    # Function to render and return each view as an image
+    def render_view(mesh, azimuth, elevation):
+        ax = fig.add_subplot(111, projection='3d')
+        ax.set_axis_off()
+
+        # Extract vertices and faces for plotting
+        vertices = mesh.vertices
+        faces = mesh.faces
+
+        # Plot faces
+        ax.add_collection3d(Poly3DCollection(
+            vertices[faces],
+            facecolors=(0.8, 0.5, 0.2, 1.0),  # Brownish yellow
+            edgecolors='k',
+            linewidths=0.5,
+        ))
+
+        # Set limits and center the view on the object
+        ax.set_xlim(center[0] - scale / 2, center[0] + scale / 2)
+        ax.set_ylim(center[1] - scale / 2, center[1] + scale / 2)
+        ax.set_zlim(center[2] - scale / 2, center[2] + scale / 2)
+
+        # Set view angle
+        ax.view_init(elev=elevation, azim=azimuth)
+
+        # Save the figure to a buffer
+        buf = io.BytesIO()
+        plt.savefig(buf, format='png', bbox_inches='tight', pad_inches=0, dpi=300)
+        plt.clf()
+        buf.seek(0)
+
+        return Image.open(buf)
+
+    # Render each view and store in a list
+    images = [render_view(mesh, az, el) for az, el in views]
+
+    # Combine images horizontally
+    widths, heights = zip(*(i.size for i in images))
+    total_width = sum(widths)
+    max_height = max(heights)
+
+    combined_image = Image.new('RGBA', (total_width, max_height))
+
+    x_offset = 0
+    for img in images:
+        combined_image.paste(img, (x_offset, 0))
+        x_offset += img.width
+
+    # Save the combined image
+    save_path = f"combined_mesh_view_{int(time.time())}.png"
+    combined_image.save(save_path)
+
+    plt.close(fig)
+    return save_path
+
+
@spaces.GPU(duration=360)
|
304 |
+
def do_inference(input_3d, sample_seed=0, do_sampling=False, do_marching_cubes=False):
|
305 |
+
set_seed(sample_seed)
|
306 |
+
print("Seed value:", sample_seed)
|
307 |
+
|
308 |
+
input_mesh = trimesh.load(input_3d)
|
309 |
+
pc_list, mesh_list = process_mesh_to_pc([input_mesh], marching_cubes = do_marching_cubes)
|
310 |
+
pc_normal = pc_list[0] # 4096, 6
|
311 |
+
mesh = mesh_list[0]
|
312 |
+
vertices = mesh.vertices
|
313 |
+
|
314 |
+
pc_coor = pc_normal[:, :3]
|
315 |
+
normals = pc_normal[:, 3:]
|
316 |
+
|
317 |
+
bounds = np.array([vertices.min(axis=0), vertices.max(axis=0)])
|
318 |
+
# scale mesh and pc
|
319 |
+
vertices = vertices - (bounds[0] + bounds[1])[None, :] / 2
|
320 |
+
vertices = vertices / (bounds[1] - bounds[0]).max()
|
321 |
+
mesh.vertices = vertices
|
322 |
+
pc_coor = pc_coor - (bounds[0] + bounds[1])[None, :] / 2
|
323 |
+
pc_coor = pc_coor / (bounds[1] - bounds[0]).max()
|
324 |
+
|
325 |
+
mesh.merge_vertices()
|
326 |
+
mesh.update_faces(mesh.nondegenerate_faces())
|
327 |
+
mesh.update_faces(mesh.unique_faces())
|
328 |
+
mesh.remove_unreferenced_vertices()
|
329 |
+
mesh.fix_normals()
|
330 |
+
try:
|
331 |
+
if mesh.visual.vertex_colors is not None:
|
332 |
+
orange_color = np.array([255, 165, 0, 255], dtype=np.uint8)
|
333 |
+
|
334 |
+
mesh.visual.vertex_colors = np.tile(orange_color, (mesh.vertices.shape[0], 1))
|
335 |
+
else:
|
336 |
+
orange_color = np.array([255, 165, 0, 255], dtype=np.uint8)
|
337 |
+
mesh.visual.vertex_colors = np.tile(orange_color, (mesh.vertices.shape[0], 1))
|
338 |
+
except Exception as e:
|
339 |
+
print(e)
|
340 |
+
input_save_name = f"processed_input_{int(time.time())}.obj"
|
341 |
+
mesh.export(input_save_name)
|
342 |
+
input_render_res = wireframe_render(mesh)
|
343 |
+
|
344 |
+
pc_coor = pc_coor / np.abs(pc_coor).max() * 0.99 # input should be from -1 to 1
|
345 |
+
|
346 |
+
assert (np.linalg.norm(normals, axis=-1) > 0.99).all(), "normals should be unit vectors, something wrong"
|
347 |
+
normalized_pc_normal = np.concatenate([pc_coor, normals], axis=-1, dtype=np.float16)
|
348 |
+
|
349 |
+
input = torch.tensor(normalized_pc_normal, dtype=torch.float16, device=device)[None]
|
350 |
+
print("Data loaded")
|
351 |
+
|
352 |
+
# with accelerator.autocast():
|
353 |
+
with accelerator.autocast():
|
354 |
+
outputs = model(input, do_sampling)
|
355 |
+
print("Model inference done")
|
356 |
+
recon_mesh = outputs[0]
|
357 |
+
|
358 |
+
valid_mask = torch.all(~torch.isnan(recon_mesh.reshape((-1, 9))), dim=1)
|
359 |
+
recon_mesh = recon_mesh[valid_mask] # nvalid_face x 3 x 3
|
360 |
+
vertices = recon_mesh.reshape(-1, 3).cpu()
|
361 |
+
vertices_index = np.arange(len(vertices)) # 0, 1, ..., 3 x face
|
362 |
+
triangles = vertices_index.reshape(-1, 3)
|
363 |
+
|
364 |
+
artist_mesh = trimesh.Trimesh(vertices=vertices, faces=triangles, force="mesh",
|
365 |
+
merge_primitives=True)
|
366 |
+
|
367 |
+
artist_mesh.merge_vertices()
|
368 |
+
artist_mesh.update_faces(artist_mesh.nondegenerate_faces())
|
369 |
+
artist_mesh.update_faces(artist_mesh.unique_faces())
|
370 |
+
artist_mesh.remove_unreferenced_vertices()
|
371 |
+
artist_mesh.fix_normals()
|
372 |
+
|
373 |
+
if artist_mesh.visual.vertex_colors is not None:
|
374 |
+
orange_color = np.array([255, 165, 0, 255], dtype=np.uint8)
|
375 |
+
|
376 |
+
artist_mesh.visual.vertex_colors = np.tile(orange_color, (artist_mesh.vertices.shape[0], 1))
|
377 |
+
else:
|
378 |
+
orange_color = np.array([255, 165, 0, 255], dtype=np.uint8)
|
379 |
+
artist_mesh.visual.vertex_colors = np.tile(orange_color, (artist_mesh.vertices.shape[0], 1))
|
380 |
+
|
381 |
+
num_faces = len(artist_mesh.faces)
|
382 |
+
|
383 |
+
brown_color = np.array([165, 42, 42, 255], dtype=np.uint8)
|
384 |
+
face_colors = np.tile(brown_color, (num_faces, 1))
|
385 |
+
|
386 |
+
artist_mesh.visual.face_colors = face_colors
|
387 |
+
# add time stamp to avoid cache
|
388 |
+
save_name = f"output_{int(time.time())}.obj"
|
389 |
+
artist_mesh.export(save_name)
|
390 |
+
output_render = wireframe_render(artist_mesh)
|
391 |
+
return input_save_name, input_render_res, save_name, output_render
|
392 |
+
|
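+# Example call (sketch; "examples/mesh.obj" is an illustrative path):
+#     processed_obj, input_png, output_obj, output_png = do_inference(
+#         "examples/mesh.obj", sample_seed=0, do_sampling=False, do_marching_cubes=False)
+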
 ###############################################################################
 # Gradio
 ###############################################################################
@@ -318,7 +496,8 @@ with gr.Blocks() as demo:
             gr.Markdown("Note: The model shown here has a darker appearance. Download to get correct results.")
     with gr.Row():
         gr.Markdown('''Try a different <b>seed value</b> if the result is unsatisfying (Default: 42).''')
-
+
+
     mv_images = gr.State()
 
     step1_submit.click(fn=check_input_image, inputs=[input_image]).success(
@@ -334,7 +513,58 @@
         inputs=[mv_images],
        outputs=[output_model_obj, output_model_glb]
     )
-
+
+    gr.Markdown(STEP2_HEADER)
+    with gr.Row(variant="panel"):
+        with gr.Column():
+            with gr.Row():
+                input_3d = gr.Model3D(
+                    label="Input Mesh",
+                    display_mode="wireframe",
+                    clear_color=[1, 1, 1, 1],
+                )
+
+            with gr.Row():
+                with gr.Group():
+                    do_marching_cubes = gr.Checkbox(label="Preprocess with Marching Cubes", value=False)
+                    do_sampling = gr.Checkbox(label="Random Sampling", value=False)
+                    sample_seed = gr.Number(value=0, label="Seed Value", precision=0)
+
+            with gr.Row():
+                submit = gr.Button("Generate", elem_id="generate", variant="primary")
+
+            with gr.Row(variant="panel"):
+                mesh_examples = gr.Examples(
+                    examples=[
+                        os.path.join("examples", img_name) for img_name in sorted(os.listdir("examples"))
+                    ],
+                    inputs=input_3d,
+                    outputs=[preprocess_model_obj, input_image_render, output_model_obj, output_image_render],
+                    fn=do_inference,
+                    cache_examples=False,
+                    examples_per_page=10,
+                )
+
+        with gr.Column():
+            with gr.Row():
+                input_image_render.render()
+            with gr.Row():
+                with gr.Tab("OBJ"):
+                    preprocess_model_obj.render()
+            with gr.Row():
+                output_image_render.render()
+            with gr.Row():
+                with gr.Tab("OBJ"):
+                    output_model_obj.render()
+            with gr.Row():
+                gr.Markdown('''Enable <b>Random Sampling</b> or try a different <b>Seed Value</b> if the result is unsatisfying.''')
+
+    submit.click(
+        fn=do_inference,
+        inputs=[input_3d, sample_seed, do_sampling, do_marching_cubes],
+        outputs=[preprocess_model_obj, input_image_render, output_model_obj, output_image_render],
+    )
+
     gr.Markdown(STEP3_HEADER)
     gr.Markdown(STEP4_HEADER)
 
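For reference, the core MeshAnythingv2 path added above reduces to the following minimal sketch. It is not part of the commit: it assumes a CUDA device is available, assumes the repo-local main.load_v2 and mesh_to_pc.process_mesh_to_pc behave as they are used in the diff, uses an illustrative file name ("input.obj"), and skips the bounds-based recentering and trimesh cleanup that do_inference performs.

import numpy as np
import torch
import trimesh
from accelerate import Accelerator
from main import load_v2                    # repo-local module, as in the diff
from mesh_to_pc import process_mesh_to_pc   # repo-local module, as in the diff

# Load the model and move it to the GPU under fp16 mixed precision.
accelerator = Accelerator(mixed_precision="fp16")
model = accelerator.prepare(load_v2())
model.eval()

# Sample an oriented point cloud (xyz + unit normals) from an input mesh.
pc_list, _ = process_mesh_to_pc([trimesh.load("input.obj")], marching_cubes=False)
coords, normals = pc_list[0][:, :3], pc_list[0][:, 3:]
coords = coords / np.abs(coords).max() * 0.99  # model expects coords in (-1, 1)

# Batch of one: a (1, 4096, 6) half-precision tensor.
pc_tensor = torch.tensor(
    np.concatenate([coords, normals], axis=-1, dtype=np.float16),
    dtype=torch.float16, device=accelerator.device,
)[None]

# The model returns face vertex triples; NaN rows are filtered in the full app.
with accelerator.autocast():
    recon_mesh = model(pc_tensor, False)[0]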