Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
import gradio as gr | |
import numpy as np | |
import os | |
import json | |
import subprocess | |
from PIL import Image | |
from functools import partial | |
from datetime import datetime | |
from sam_inference import get_sam_predictor, sam_seg | |
from utils import blend_seg, blend_seg_pure | |
import cv2 | |
import uuid | |
import torch | |
import trimesh | |
from huggingface_hub import snapshot_download | |
from gradio_model3dcolor import Model3DColor | |
# from gradio_model3dnormal import Model3DNormal | |
# Fetch the MeshFormer API bundle from the Hugging Face Hub. The HF_TOKEN
# environment variable is required: the lookup raises KeyError when it is unset.
code_dir = snapshot_download("sudo-ai/MeshFormer-API", token=os.environ["HF_TOKEN"])

# api.json provides the shell command templates used by preprocess() and
# mesh_gen(). JSON is UTF-8 by specification, so the encoding is stated
# explicitly instead of relying on the platform default.
with open(os.path.join(code_dir, "api.json"), "r", encoding="utf-8") as file:
    api_dict = json.load(file)

SEG_CMD = api_dict["SEG_CMD"]  # formatted with tmp_dir
MESH_CMD = api_dict["MESH_CMD"]  # formatted with tmp_dir and num_inference_steps
# Bootstrap stylesheet tag that message() prepends to every alert banner.
STYLE = (
    "\n"
    '<link href="https://cdn.jsdelivr.net/npm/bootstrap@5.3.2/dist/css/bootstrap.min.css"'
    ' rel="stylesheet"'
    ' integrity="sha384-T3c6CoIi6uLrA9TneNEoa7RxnatzjcDSCmG1MXxSR1GAsXEV/Dwwykc2MPK8M2HN"'
    ' crossorigin="anonymous">'
    "\n"
)

# All four icons share the same 16x16 <svg> wrapper; only the Bootstrap icon
# class and the path outline differ.
_SVG_ICON_TEMPLATE = (
    '<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" fill="currentColor"'
    ' class="bi {icon_class} flex-shrink-0 me-2" viewBox="0 0 16 16">\n'
    '<path d="{path_data}"/>\n'
    "</svg>"
)

# icon key -> (Bootstrap icon class, SVG path data)
_ICON_SHAPES = {
    "info": (
        "bi-info-circle-fill",
        "M8 16A8 8 0 1 0 8 0a8 8 0 0 0 0 16zm.93-9.412-1 4.705c-.07.34.029.533.304.533.194 0 .487-.07.686-.246l-.088.416c-.287.346-.92.598-1.465.598-.703 0-1.002-.422-.808-1.319l.738-3.468c.064-.293.006-.399-.287-.47l-.451-.081.082-.381 2.29-.287zM8 5.5a1 1 0 1 1 0-2 1 1 0 0 1 0 2z",
    ),
    "cursor": (
        "bi-hand-index-thumb-fill",
        "M8.5 1.75v2.716l.047-.002c.312-.012.742-.016 1.051.046.28.056.543.18.738.288.273.152.456.385.56.642l.132-.012c.312-.024.794-.038 1.158.108.37.148.689.487.88.716.075.09.141.175.195.248h.582a2 2 0 0 1 1.99 2.199l-.272 2.715a3.5 3.5 0 0 1-.444 1.389l-1.395 2.441A1.5 1.5 0 0 1 12.42 16H6.118a1.5 1.5 0 0 1-1.342-.83l-1.215-2.43L1.07 8.589a1.517 1.517 0 0 1 2.373-1.852L5 8.293V1.75a1.75 1.75 0 0 1 3.5 0z",
    ),
    "wait": (
        "bi-hourglass-split",
        "M2.5 15a.5.5 0 1 1 0-1h1v-1a4.5 4.5 0 0 1 2.557-4.06c.29-.139.443-.377.443-.59v-.7c0-.213-.154-.451-.443-.59A4.5 4.5 0 0 1 3.5 3V2h-1a.5.5 0 0 1 0-1h11a.5.5 0 0 1 0 1h-1v1a4.5 4.5 0 0 1-2.557 4.06c-.29.139-.443.377-.443.59v.7c0 .213.154.451.443.59A4.5 4.5 0 0 1 12.5 13v1h1a.5.5 0 0 1 0 1h-11zm2-13v1c0 .537.12 1.045.337 1.5h6.326c.216-.455.337-.963.337-1.5V2h-7zm3 6.35c0 .701-.478 1.236-1.011 1.492A3.5 3.5 0 0 0 4.5 13s.866-1.299 3-1.48V8.35zm1 0v3.17c2.134.181 3 1.48 3 1.48a3.5 3.5 0 0 0-1.989-3.158C8.978 9.586 8.5 9.052 8.5 8.351z",
    ),
    "done": (
        "bi-check-circle-fill",
        "M16 8A8 8 0 1 1 0 8a8 8 0 0 1 16 0zm-3.97-3.03a.75.75 0 0 0-1.08.022L7.477 9.417 5.384 7.323a.75.75 0 0 0-1.06 1.06L6.97 11.03a.75.75 0 0 0 1.079-.02l3.992-4.99a.75.75 0 0 0-.01-1.05z",
    ),
}

# Inline SVG markup per status: info (info-circle-fill), cursor
# (hand-index-thumb-fill), wait (hourglass-split), done (check-circle-fill).
ICONS = {
    icon_key: _SVG_ICON_TEMPLATE.format(icon_class=icon_class, path_data=path_data)
    for icon_key, (icon_class, path_data) in _ICON_SHAPES.items()
}

# Bootstrap alert variant used for each icon key.
icons2alert = dict(
    info="primary",  # blue
    cursor="info",  # light blue
    wait="secondary",  # gray
    done="success",  # green
)
def message(text, icon_type="info"):
    """Render *text* as a Bootstrap alert banner and return the HTML string.

    ``icon_type`` selects both the inline SVG from ``ICONS`` and the alert
    variant from ``icons2alert``; a key missing from either raises ``KeyError``.
    ``text`` is inserted verbatim, so it may carry HTML markup.
    """
    alert_variant = icons2alert[icon_type]
    icon_svg = ICONS[icon_type]
    banner_open = (
        f'{STYLE} <div class="alert alert-{alert_variant} d-flex align-items-center"'
        f' role="alert"> {icon_svg}'
    )
    return "\n".join((banner_open, "<div>", f"{text}", "</div>", "</div>"))
def preprocess(tmp_dir, input_img, idx=None):
    """Run the external segmentation command and return its 320x320 result.

    Args:
        tmp_dir: Working directory. The input is written to
            ``{tmp_dir}/input.png`` and the segmentation command is expected
            to leave its result at ``{tmp_dir}/seg.png``.
        input_img: A PIL image, or, when ``idx`` is given, an indexable
            collection whose items hold a file path under the ``"name"`` key.
        idx: Optional index selecting the entry of ``input_img`` to open.

    Returns:
        The segmented image resized to 320x320 with LANCZOS resampling.

    Raises:
        FileNotFoundError: If the command did not produce ``seg.png``.
    """
    if idx is not None:
        print("image idx:", int(idx))
        input_img = Image.open(input_img[int(idx)]["name"])
    input_img.save(f"{tmp_dir}/input.png")

    seg_path = f"{tmp_dir}/seg.png"
    status = os.system(SEG_CMD.format(tmp_dir=tmp_dir))
    if not os.path.exists(seg_path):
        # Same exception type Image.open() would raise for the missing file,
        # but pointing at the command that failed rather than at the path.
        raise FileNotFoundError(
            f"Segmentation command exited with status {status} "
            f"without producing {seg_path}"
        )
    if status != 0:
        print("Segmentation command returned non-zero status:", status)

    # resize() returns an independent image, so the file handle can be closed.
    with Image.open(seg_path) as processed_img:
        return processed_img.resize((320, 320), Image.Resampling.LANCZOS)
def ply_to_glb(ply_path):
    """Convert a PLY file to GLB by running ``ply2glb.py`` in a subprocess.

    Args:
        ply_path: Path of the ``.ply`` file to convert.

    Returns:
        ``ply_path`` with its file extension replaced by ``.glb``. The script
        is presumed to write its output there (not visible from this file);
        the path is returned even when the script fails, as before.
    """
    import sys

    result = subprocess.run(
        # Use the interpreter that is running this app so the child process
        # sees the same environment; a bare "python" may resolve elsewhere.
        [sys.executable, "ply2glb.py", "--", ply_path],
        capture_output=True,
        text=True,
    )
    print("Output of blender script:")
    print(result.stdout)
    if result.returncode != 0:
        # stderr is captured above; surface it instead of dropping it silently.
        print(f"ply2glb.py exited with status {result.returncode}:")
        print(result.stderr)

    # Swap only the file extension. str.replace() would also rewrite a ".ply"
    # that happens to appear in a directory name.
    stem, _extension = os.path.splitext(ply_path)
    return stem + ".glb"
def mesh_gen(tmp_dir, simplify, num_inference_steps):
    """Run the mesh command, then export colour and normal-shaded GLB files.

    Args:
        tmp_dir: Working directory; the mesh command is expected to leave
            ``mesh.ply`` there.
        simplify: Accepted for interface compatibility; not used here.
        num_inference_steps: Value substituted into the mesh command template.

    Returns:
        ``(color_path, normal_path)`` as returned by ``ply_to_glb``.
    """
    command = MESH_CMD.format(
        tmp_dir=tmp_dir, num_inference_steps=num_inference_steps
    )
    os.system(command)

    color_ply = f"{tmp_dir}/mesh.ply"
    normal_ply = f"{tmp_dir}/mesh_normal.ply"
    mesh = trimesh.load_mesh(color_ply)

    # Rotate the vertex normals by 180 degrees about the z axis.
    angle = np.radians(180)
    cos_a, sin_a = np.cos(angle), np.sin(angle)
    z_rotation = np.array(
        [
            [cos_a, -sin_a, 0],
            [sin_a, cos_a, 0],
            [0, 0, 1],
        ]
    )
    rotated_normals = np.dot(mesh.vertex_normals, z_rotation.T)

    # Map the negated normal components from [-1, 1] to 8-bit colour values.
    shaded = (1 - rotated_normals) / 2.0
    shaded = (shaded * 255).clip(0, 255).astype(np.uint8)
    mesh.visual.vertex_colors = shaded[..., [2, 1, 0]]  # RGB -> BGR
    mesh.export(normal_ply, file_type="ply")

    color_path = ply_to_glb(color_ply)
    normal_path = ply_to_glb(normal_ply)
    return color_path, normal_path
def create_tmp_dir():
    """Create and return a fresh working directory under ``demo_exp/``.

    The directory name is the current local timestamp followed by the first
    four characters of a random UUID.
    """
    timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    suffix = str(uuid.uuid4())[:4]
    tmp_dir = f"demo_exp/{timestamp}_{suffix}"
    os.makedirs(tmp_dir, exist_ok=True)
    print("create tmp_exp_dir", tmp_dir)
    return tmp_dir
def vis_seg(checkbox):
    """Show or hide the manual-segmentation widgets.

    Returns a five-item list: the same image update twice, a radio update,
    an empty list, and a checkbox update. The caller maps these onto its
    outputs in that order. Showing the widgets also clears the image value.
    """
    show = bool(checkbox)
    if show:
        print("Show manual seg windows")
        image_update = gr.Image(value=None, visible=True)
    else:
        print("Clear manual seg")
        image_update = gr.Image(visible=False)
    radio_update = gr.Radio(visible=show)
    mask_option_update = gr.Checkbox(visible=show)
    return [image_update, image_update, radio_update, [], mask_option_update]
def calc_feat(checkbox, predictor, input_image, idx=None):
    """Prepare the click-to-segment canvas and load it into the SAM predictor.

    Args:
        checkbox: Whether manual segmentation is enabled.
        predictor: Object exposing ``set_image(ndarray)``.
        input_image: A PIL image, or, when ``idx`` is given, an indexable
            collection whose items hold a file path under the ``"name"`` key.
        idx: Optional index selecting the entry of ``input_image`` to open.

    Returns:
        A visible ``gr.Image`` update holding the white-padded square image.

    Raises:
        ValueError: If manual segmentation is disabled (this deliberately
            stops the caller's event chain) or if there is no input image.
    """
    if not checkbox:
        print("Quit manual seg")
        raise ValueError("Quit manual seg")

    if idx is not None:
        print("image idx:", int(idx))
        input_image = Image.open(input_image[int(idx)]["name"])
    if input_image is None:
        # Previously this surfaced as an AttributeError on None.
        raise ValueError("No input image available for manual segmentation.")

    # thumbnail() shrinks the image in place so neither side exceeds 512.
    input_image.thumbnail([512, 512], Image.Resampling.LANCZOS)
    w, h = input_image.size
    print("image size:", w, h)

    # Centre the image on a white square canvas. Plain ints are used here;
    # np.max would hand numpy scalars to PIL.
    side_len = max(w, h)
    seg_in = Image.new(input_image.mode, (side_len, side_len), (255, 255, 255))
    seg_in.paste(input_image, (max(0, (h - w) // 2), max(0, (w - h) // 2)))

    print("Calculating image SAM feature...")
    predictor.set_image(np.array(seg_in.convert("RGB")))
    torch.cuda.empty_cache()
    return gr.Image(value=seg_in, visible=True)
def manual_seg(
    predictor,
    seg_in,
    selected_points,
    fg_bg_radio,
    tmp_dir,
    seg_mask_opt,
    evt: gr.SelectData,
):
    """Segment ``seg_in`` from the clicked points and save the cropped result.

    Args:
        predictor: SAM predictor passed through to ``sam_seg``.
        seg_in: PIL image the user clicked on.
        selected_points: List of earlier clicks; the new click is appended to
            it in place.
        fg_bg_radio: Radio value; ``"+ (add mask)"`` marks a foreground point,
            anything else a background point.
        tmp_dir: Directory that receives ``seg.png``.
        seg_mask_opt: When truthy the preview is drawn with ``blend_seg_pure``,
            otherwise with ``blend_seg``.
        evt: Gradio select event; ``evt.index`` is the clicked coordinate.

    Returns:
        ``(preview, rgba)``: the blended preview resized to 380x380 and the
        recentred cut-out resized to 320x320.

    Raises:
        ValueError: If the segmentation mask is empty.
    """
    print("Start segmentation")
    selected_points.append(
        {"coord": evt.index, "add_del": fg_bg_radio == "+ (add mask)"}
    )
    input_points = np.array([point["coord"] for point in selected_points])
    input_labels = np.array([point["add_del"] for point in selected_points])

    seg_rgb = seg_in.convert("RGB")
    out_image = sam_seg(predictor, np.array(seg_rgb), input_points, input_labels)

    blend = blend_seg_pure if seg_mask_opt else blend_seg
    segmentation = blend(seg_rgb, out_image, input_points, input_labels)

    # recenter and rescale
    # NOTE(review): assumes sam_seg returns an RGBA image whose last band is
    # the mask (it is copied into a 4-channel buffer below) - confirm.
    image_arr = np.array(out_image)
    _, mask = cv2.threshold(
        np.array(out_image.split()[-1]), 0, 255, cv2.THRESH_BINARY
    )
    x, y, w, h = cv2.boundingRect(mask)
    if w == 0 or h == 0:
        # An empty mask would give a zero-sized canvas that cannot be saved
        # or resized; fail with an actionable message instead.
        raise ValueError(
            "Manual segmentation produced an empty mask; "
            "click on the object to add a foreground point."
        )

    # Pad so the object's longer side fills 75% of a square canvas.
    max_size = max(w, h)
    ratio = 0.75
    side_len = int(max_size / ratio)
    padded_image = np.zeros((side_len, side_len, 4), dtype=np.uint8)
    center = side_len // 2
    top = center - h // 2
    left = center - w // 2
    padded_image[top : top + h, left : left + w] = image_arr[y : y + h, x : x + w]

    rgba = Image.fromarray(padded_image)
    rgba.save(f"{tmp_dir}/seg.png")
    torch.cuda.empty_cache()
    return segmentation.resize((380, 380), Image.Resampling.LANCZOS), rgba.resize(
        (320, 320), Image.Resampling.LANCZOS
    )
# Soft theme in blue, with secondary buttons given neutral fills.
_SECONDARY_BUTTON_FILLS = {
    "button_secondary_background_fill": "*neutral_100",
    "button_secondary_background_fill_hover": "*neutral_200",
}
custom_theme = gr.themes.Soft(primary_hue="blue").set(**_SECONDARY_BUTTON_FILLS)
# Build the single-page UI and wire its events.
# NOTE(review): this file arrived with its indentation stripped; the nesting
# below is reconstructed from syntax and line-wrapping and should be confirmed
# against the deployed layout.
with gr.Blocks(title="MeshFormer Demo", css="style.css", theme=custom_theme) as demo:
    # ---- Page header: title, links, pointer to the hosted product ----
    with gr.Row():
        gr.Markdown(
            "# MeshFormer: High-Quality Mesh Generation with 3D-Guided Reconstruction Model"
        )
    with gr.Row():
        gr.Markdown(
            "[Project Page](https://meshformer3d.github.io/) | [arXiv](https://arxiv.org/abs/TBD)"
        )
    with gr.Row():
        gr.Markdown(
            """
<div>
<b><em>Check out <a href="https://www.sudo.ai/3dgen">Hillbot (sudoAI)</a> for more details and advanced features.</em></b>
</div>
"""
        )
    # Status banner; the event chains below rewrite it through update_guide.
    with gr.Row():
        guide_text_i2m = gr.HTML(message("Please input an image!"), visible=True)
    # Per-session working directory; replaced by create_tmp_dir() in the
    # event chains below.
    tmp_dir_img = gr.State("./demo_exp/placeholder")
    # NOTE(review): the next two states are not referenced anywhere else in
    # this block.
    tmp_dir_txt = gr.State("./demo_exp/placeholder")
    tmp_dir_3t3 = gr.State("./demo_exp/placeholder")
    # Every entry of demo_examples/ (next to this script) is offered as an
    # example image, in sorted order.
    example_folder = os.path.join(os.path.dirname(__file__), "demo_examples")
    example_fns = os.listdir(example_folder)
    example_fns.sort()
    img_examples = [
        os.path.join(example_folder, x) for x in example_fns
    ]  # if x.endswith('.png') or x.endswith('.')
    # ---- Main panel: inputs on the left, results on the right ----
    with gr.Row(variant="panel"):
        with gr.Row():
            with gr.Column(scale=8):
                input_image = gr.Image(
                    type="pil",
                    image_mode="RGBA",
                    height=320,
                    label="Input Image",
                    interactive=True,
                )
                gr.Examples(
                    examples=img_examples,
                    inputs=[input_image],
                    outputs=[input_image],
                    cache_examples=False,
                    label="Image Examples (Click one of the images below to start)",
                    examples_per_page=27,
                )
                with gr.Accordion("Options", open=False):
                    # Hidden option; its value is still passed to mesh_gen.
                    img_simplify = gr.Checkbox(
                        False, label="simplify the generated mesh", visible=False
                    )
                    n_steps_img = gr.Slider(
                        value=28,
                        minimum=15,
                        maximum=100,
                        step=1,
                        label="number of inference steps",
                    )
                # manual segmentation: these widgets start hidden and are
                # toggled by vis_seg when the checkbox changes.
                checkbox_manual_seg = gr.Checkbox(False, label="manual segmentation")
                with gr.Row():
                    with gr.Column(scale=1):
                        seg_in = gr.Image(
                            type="pil",
                            image_mode="RGBA",
                            label="Click to segment",
                            visible=False,
                            show_download_button=False,
                            height=380,
                        )
                    with gr.Column(scale=1):
                        seg_out = gr.Image(
                            type="pil",
                            image_mode="RGBA",
                            label="Segmentation",
                            interactive=False,
                            visible=False,
                            show_download_button=False,
                            height=380,
                            elem_id="disp_image",
                        )
                fg_bg_radio = gr.Radio(
                    ["+ (add mask)", "- (remove area)"],
                    value="+ (add mask)",
                    info="Select foreground (+) or background (-) point",
                    label="Point label",
                    visible=False,
                    interactive=True,
                )
                seg_mask_opt = gr.Checkbox(
                    True,
                    label="show foreground mask in manual segmentation",
                    visible=False,
                )
                # run: disabled until preprocessing of the input has succeeded.
                img_run_btn = gr.Button(
                    "Generate", variant="primary", interactive=False
                )
            with gr.Column(scale=6):
                # NOTE(review): shares elem_id "disp_image" with seg_out.
                processed_image = gr.Image(
                    type="pil",
                    label="Processed Image",
                    interactive=False,
                    height=320,
                    image_mode="RGBA",
                    elem_id="disp_image",
                )
                # with gr.Row():
                # mesh_output = gr.Model3D(label="Generated Mesh", elem_id="model-3d-out")
                mesh_output_normal = Model3DColor(
                    label="Generated Mesh (normal)",
                    elem_id="mesh-normal-out",
                    height=400,
                )
                mesh_output = Model3DColor(
                    label="Generated Mesh (color)",
                    elem_id="mesh-out",
                    height=400,
                )
    # ---- Session state ----
    predictor = gr.State(value=get_sam_predictor())
    # Clicks collected by manual_seg; vis_seg resets this to an empty list.
    selected_points = gr.State(value=[])
    # NOTE(review): not referenced anywhere else in this block.
    selected_points_t2i = gr.State(value=[])
    # ---- Small component-update helpers used in the event chains ----
    disable_checkbox = lambda: gr.Checkbox(value=False)
    disable_button = lambda: gr.Button(interactive=False)
    enable_button = lambda: gr.Button(interactive=True)
    update_guide = lambda GUIDE_TEXT, icon_type="info": gr.HTML(
        value=message(GUIDE_TEXT, icon_type)
    )
    # NOTE(review): update_md is not referenced anywhere else in this block.
    update_md = lambda GUIDE_TEXT: gr.Markdown(value=GUIDE_TEXT)

    def is_img_clear(input_image):
        """Raise ValueError when the input image is empty, to stop the chain."""
        if not input_image:
            raise ValueError("Input image cleared.")

    # Manual-segmentation toggle: show/hide the widgets, then compute the SAM
    # features, then allocate a fresh working directory. Each .success() step
    # runs only if the previous one did not raise; calc_feat raises when the
    # box is unticked, so the last step is skipped in that case.
    checkbox_manual_seg.change(
        vis_seg,
        inputs=[checkbox_manual_seg],
        outputs=[seg_in, seg_out, fg_bg_radio, selected_points, seg_mask_opt],
        queue=False,
    ).success(
        calc_feat,
        inputs=[checkbox_manual_seg, predictor, input_image],
        outputs=[seg_in],
    ).success(
        fn=create_tmp_dir, outputs=[tmp_dir_img], queue=False
    )
    # Each click on the segmentation canvas re-runs SAM with all points so far
    # and refreshes both the preview and the processed image.
    seg_in.select(
        manual_seg,
        [predictor, seg_in, selected_points, fg_bg_radio, tmp_dir_img, seg_mask_opt],
        [seg_out, processed_image],
    )
    # New or cleared input image: disable Generate, untick manual segmentation,
    # stop here if the image was cleared, otherwise create a working directory,
    # run automatic preprocessing and finally re-enable Generate.
    input_image.change(disable_button, outputs=img_run_btn, queue=False).success(
        disable_checkbox, outputs=checkbox_manual_seg, queue=False
    ).success(fn=is_img_clear, inputs=input_image, queue=False).success(
        fn=create_tmp_dir, outputs=tmp_dir_img, queue=False
    ).success(
        fn=partial(update_guide, "Preprocessing the image!", "wait"),
        outputs=[guide_text_i2m],
        queue=False,
    ).success(
        fn=preprocess,
        inputs=[tmp_dir_img, input_image],
        outputs=[processed_image],
        queue=True,
    ).success(
        fn=partial(
            update_guide,
            "Click <b>Generate</b> to generate mesh! If the input image was not segmented accurately, please adjust it using <b>manual segmentation</b>.",
            "cursor",
        ),
        outputs=[guide_text_i2m],
        queue=False,
    ).success(
        enable_button, outputs=img_run_btn, queue=False
    )
    # Generate: update the banner, run mesh generation in the current working
    # directory, then report success. mesh_gen returns (color, normal) paths,
    # matching the output order below.
    img_run_btn.click(
        fn=partial(update_guide, "Generating the mesh!", "wait"),
        outputs=[guide_text_i2m],
        queue=False,
    ).success(
        fn=mesh_gen,
        inputs=[tmp_dir_img, img_simplify, n_steps_img],
        outputs=[mesh_output, mesh_output_normal],
        queue=True,
    ).success(
        fn=partial(
            update_guide,
            "Successfully generated the mesh. (It might take a few seconds to load the mesh)",
            "done",
        ),
        outputs=[guide_text_i2m],
        queue=False,
    )
# Enable the request queue and start the server, listening on all interfaces.
launch_options = {
    "debug": True,
    "share": False,
    "inline": False,
    "show_api": False,
    "server_name": "0.0.0.0",
}
demo.queue().launch(**launch_options)