# 3DFuse / gradio_app.py
# Uploaded by jyseo — commit d661b19 ("first commit"), 14.6 kB
import torch
import argparse
import gradio as gr
from functools import partial
from my.config import BaseConf, dispatch_gradio
from run_3DFuse import SJC_3DFuse
import numpy as np
from PIL import Image
from pc_project import point_e
from diffusers import UnCLIPPipeline, DiffusionPipeline
from pc_project import point_e_gradio
import numpy as np
import plotly.graph_objs as go
from my.utils.seed import seed_everything
import os
# Markdown banner rendered at the top of the demo. It warns that the shared
# Space may be slow and links to duplicating the Space or the project page.
# (Plain string: it has no placeholders, so the former f-prefix was unused.)
SHARED_UI_WARNING = '''### [NOTE] Training may be very slow in this shared UI.
You can duplicate and use it with a paid private GPU.
<a class="duplicate-button" style="display:inline-block" target="_blank" href="https://huggingface.co/spaces/jyseo/3DFuse?duplicate=true"><img style="margin-top:0;margin-bottom:0" src="https://img.shields.io/badge/-Duplicate%20Space-blue?labelColor=white&style=flat&logo=&logoWidth=14" alt="Duplicate Space"></a>
Alternatively, you can also use the Colab demo on our project page.
<a style="display:inline-block" href="https://ku-cvlab.github.io/3DFuse/"><img style="margin-top:0;margin-bottom:0" src="https://img.shields.io/badge/Project%20Page-online-brightgreen"></a>
'''
class Intermediate:
    """Mutable state handed from the point-cloud step to the 3D-generation step.

    Attributes (all empty on construction):
        images: initial image list stored by the last point-cloud callback, or None.
        points: point cloud produced by ``point_e_gradio`` for those images, or None.
        is_generating: flag that ``gen_3d`` raises while a 3D run is in progress.
    """

    def __init__(self):
        # Nothing generated yet; gen_3d refuses to run until both are filled in.
        self.images, self.points = None, None
        self.is_generating = False
def gen_3d(model, intermediate, prompt, keyword, seed, ti_step, pt_step):
    """Run 3DFuse on the cached point cloud and stream its progress to the UI.

    This is a generator: it re-yields everything produced by the dispatched
    model's ``run_gradio`` (the UI wires these to the intermediate-image,
    log and video outputs).

    Args:
        model: ignored. Kept only so the ``functools.partial`` binding in the
            UI still matches; a fresh model is dispatched for every run.
        intermediate: state filled by ``gen_pc_from_prompt`` or
            ``gen_pc_from_image``; both ``images`` and ``points`` must be set.
        prompt: text prompt for generation.
        keyword: Textual-Inversion keyword.
        seed: random seed, applied via ``seed_everything`` and forwarded to
            the model configuration.
        ti_step: number of text-embedding optimization steps.
        pt_step: number of LoRA pivotal-tuning steps.

    Raises:
        gr.Error: if no point cloud has been generated yet.
    """
    # Validate before flagging a run as in progress, so a rejected request
    # cannot leave ``is_generating`` stuck at True.
    images, points = intermediate.images, intermediate.points
    if images is None or points is None:
        raise gr.Error("Please generate point cloud first")

    intermediate.is_generating = True
    try:
        seed_everything(seed)
        sjc = dispatch_gradio(SJC_3DFuse, prompt, keyword, ti_step, pt_step, seed)
        yield from sjc.run_gradio(points, images)
    finally:
        # Also runs when the run raises or Gradio closes the generator
        # (e.g. client disconnect).
        intermediate.is_generating = False
def gen_pc_from_prompt(intermediate, num_initial_image, prompt, keyword, type, bg_preprocess, seed):
    """Generate initial views from a text prompt and preview their point cloud.

    The first generated image is lifted to a point cloud with
    ``point_e_gradio``. The images and the point cloud are cached on
    ``intermediate`` for a later ``gen_3d`` call.

    Args:
        intermediate: shared state object; ``images`` and ``points`` are overwritten.
        num_initial_image: number of views passed to ``gen_init``. The first view
            is used for the point cloud, so at least 1 is required.
        prompt: text prompt for the initial images.
        keyword: single word that must occur in ``prompt`` (used for Textual Inversion).
        type: backbone name forwarded to ``gen_init``.
        bg_preprocess: forwarded to ``gen_init`` (background removal).
        seed: random seed for ``seed_everything`` and ``gen_init``.

    Returns:
        ``(first_image, plotly_figure, gr.update(interactive=True))``; the last
        item enables the "Generate 3D" button it is wired to.

    Raises:
        gr.Error: if ``keyword`` is not contained in ``prompt``, or is empty /
            contains whitespace.
    """
    # Validate before touching global RNG state or loading any model.
    if keyword not in prompt:
        raise gr.Error("Prompt should contain keyword!")
    if not keyword or any(ch.isspace() for ch in keyword):
        # An empty keyword used to slip through: "" is a substring of every
        # prompt and contains no space.
        raise gr.Error("Keyword should be one word!")
    seed_everything(seed=seed)

    images = gen_init(num_initial_image=num_initial_image, prompt=prompt, seed=seed,
                      type=type, bg_preprocess=bg_preprocess)
    points = point_e_gradio(images[0], 'cuda')
    intermediate.images = images
    intermediate.points = points

    coords = np.asarray(points.coords)
    trace = go.Scatter3d(x=coords[:, 0], y=coords[:, 1], z=coords[:, 2],
                         mode='markers', marker=dict(size=2))
    # Same bare style for all three axes: no title, grid, zero line, axis line or ticks.
    hidden_axis = dict(title="", showgrid=False, zeroline=False, showline=False,
                       ticks='', showticklabels=False)
    layout = go.Layout(
        scene=dict(xaxis=hidden_axis, yaxis=hidden_axis, zaxis=hidden_axis),
        margin=dict(l=0, r=0, b=0, t=0),
        showlegend=False,
    )
    fig = go.Figure(data=[trace], layout=layout)
    return images[0], fig, gr.update(interactive=True)
def gen_pc_from_image(intermediate, image, prompt, keyword, bg_preprocess, seed):
    """Infer a point cloud from an uploaded image and preview it.

    Optionally whitens the image background, then lifts the image to a point
    cloud with ``point_e_gradio``. The image and the point cloud are cached on
    ``intermediate`` for a later ``gen_3d`` call.

    Args:
        intermediate: shared state object; ``images`` and ``points`` are overwritten.
        image: uploaded PIL image (None when nothing was uploaded).
        prompt: text prompt describing the object.
        keyword: single word that must occur in ``prompt`` (used for Textual Inversion).
        bg_preprocess: if True, remove the background with carvekit and paint it white.
        seed: random seed for ``seed_everything``.

    Returns:
        ``(image, plotly_figure, gr.update(interactive=True))``; the last item
        enables the "Generate 3D" button it is wired to.

    Raises:
        gr.Error: if no image was uploaded, or the keyword is invalid.
    """
    if image is None:
        raise gr.Error("Please upload an image first")
    if keyword not in prompt:
        raise gr.Error("Prompt should contain keyword!")
    if not keyword or any(ch.isspace() for ch in keyword):
        # An empty keyword used to slip through: "" is a substring of every
        # prompt and contains no space.
        raise gr.Error("Keyword should be one word!")
    seed_everything(seed=seed)

    if bg_preprocess:
        # motivated by NeuralLift-360 (removing bg)
        from carvekit.api.high import HiInterface
        interface = HiInterface(object_type="object",
                                batch_size_seg=5,
                                batch_size_matting=1,
                                device='cuda' if torch.cuda.is_available() else 'cpu',
                                seg_mask_size=640,  # Use 640 for Tracer B7 and 320 for U2Net
                                matting_mask_size=2048,
                                trimap_prob_threshold=231,
                                trimap_dilation=30,
                                trimap_erosion_iters=5,
                                fp16=False)
        rgb_image = image.convert("RGB")
        rgb = np.array(rgb_image)
        cutout = np.array(interface([rgb_image])[0])
        # Assumes carvekit returns an RGBA cutout whose alpha (last) channel is
        # the foreground mask (same convention as Zero123's preprocessing).
        # Comparing every channel gave an (H, W, 4) mask, which cannot index
        # an (H, W, 3) image.
        alpha = cutout[..., -1] if cutout.ndim == 3 else cutout
        rgb[alpha <= 127] = 255
        image = Image.fromarray(rgb)

    points = point_e_gradio(image, 'cuda')
    intermediate.images = [image]
    intermediate.points = points

    coords = np.asarray(points.coords)
    trace = go.Scatter3d(x=coords[:, 0], y=coords[:, 1], z=coords[:, 2],
                         mode='markers', marker=dict(size=2))
    # Same bare style for all three axes: no title, grid, zero line, axis line or ticks.
    hidden_axis = dict(title="", showgrid=False, zeroline=False, showline=False,
                       ticks='', showticklabels=False)
    layout = go.Layout(
        scene=dict(xaxis=hidden_axis, yaxis=hidden_axis, zaxis=hidden_axis),
        margin=dict(l=0, r=0, b=0, t=0),
        showlegend=False,
    )
    fig = go.Figure(data=[trace], layout=layout)
    return image, fig, gr.update(interactive=True)
def gen_init(num_initial_image, prompt, seed, type="Karlo", bg_preprocess=False):
    """Generate candidate views of ``prompt`` with a text-to-image pipeline.

    Args:
        num_initial_image: number of images to generate. View prefixes cycle
            through "front / overhead / side / back view of".
        prompt: object description.
        seed: seed of the CUDA ``torch.Generator`` driving the pipeline.
        type: backbone name. Any value starting with "Karlo" (the UI label
            "Karlo (Recommended)" as well as this function's default) selects
            ``kakaobrain/karlo-v1-alpha``; anything else selects
            ``runwayml/stable-diffusion-v1-5``.
        bg_preprocess: if True, remove each image's background with carvekit
            and paint it white.

    Returns:
        List of generated PIL images, in generation order.
    """
    # The default "Karlo" never matched the exact UI label
    # "Karlo (Recommended)", so it silently fell back to Stable Diffusion.
    if type.startswith("Karlo"):
        pipe = UnCLIPPipeline.from_pretrained("kakaobrain/karlo-v1-alpha", torch_dtype=torch.float16)
    else:
        pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
    # NOTE(review): the pipeline is re-initialised on every call; caching it
    # would make repeated clicks much cheaper.
    pipe = pipe.to('cuda')

    view_prompt = ["front view of ", "overhead view of ", "side view of ", "back view of "]
    if bg_preprocess:
        from carvekit.api.high import HiInterface
        interface = HiInterface(object_type="object",
                                batch_size_seg=5,
                                batch_size_matting=1,
                                device='cuda' if torch.cuda.is_available() else 'cpu',
                                seg_mask_size=640,  # Use 640 for Tracer B7 and 320 for U2Net
                                matting_mask_size=2048,
                                trimap_prob_threshold=231,
                                trimap_dilation=30,
                                trimap_erosion_iters=5,
                                fp16=False)

    images = []
    generator = torch.Generator(device='cuda').manual_seed(seed)
    for i in range(num_initial_image):
        # Only the first (front) view asks for a white background. The old
        # `bg_preprocess` conditional had identical branches.
        suffix = ", white background" if i == 0 else ""
        prompt_ = f"{view_prompt[i % 4]}{prompt}{suffix}"
        image = pipe(prompt_, generator=generator).images[0]
        if bg_preprocess:
            # motivated by NeuralLift-360 (removing bg)
            # NOTE: This option was added during the code organization process.
            # The results reported in the paper were obtained with [bg_preprocess: False] setting.
            rgb = np.array(image.convert("RGB"))
            cutout = np.array(interface([image])[0])
            # Assumes carvekit returns an RGBA cutout whose alpha (last)
            # channel is the foreground mask. Comparing every channel gave an
            # (H, W, 4) mask, which cannot index an (H, W, 3) image.
            alpha = cutout[..., -1] if cutout.ndim == 3 else cutout
            rgb[alpha <= 127] = 255
            image = Image.fromarray(rgb)
        images.append(image)
    return images
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--share', action='store_true', help="public url")
    args = parser.parse_args()

    # NOTE(review): a single Intermediate instance is bound into every callback
    # below, so all visitors share the same cached images / point cloud and
    # concurrent sessions can overwrite each other's preview. Per-session
    # gr.State would avoid that.
    intermediate = Intermediate()
    demo = gr.Blocks(title="3DFuse Interactive Demo")
    with demo:
        with gr.Box():
            gr.Markdown(SHARED_UI_WARNING)
        gr.Markdown("# 3DFuse Interactive Demo")
        gr.Markdown("### Official Implementation of the paper \"Let 2D Diffusion Model Know 3D-Consistency for Robust Text-to-3D Generation\"")
        gr.Markdown("Enter your own prompt and enjoy! With this demo, you can preview the point cloud before 3D generation and determine the desired shape.")
        with gr.Row():
            # Left column: inputs (two tabs) plus point-cloud / initial-image preview.
            with gr.Column(scale=1., variant='panel'):
                with gr.Tab("Text to 3D"):
                    prompt_input = gr.Textbox(label="Prompt", value="a comfortable bed", interactive=True)
                    word_input = gr.Textbox(label="Keyword for initialization (should be a noun included in the prompt)", value="bed", interactive=True)
                    semantic_model_choice = gr.Radio(["Karlo (Recommended)","Stable Diffusion"], label="Backbone for initial image generation", value="Karlo (Recommended)", interactive=True)
                    seed = gr.Slider(label="Seed", minimum=0, maximum=2100000000, step=1, randomize=True)
                    preprocess_choice = gr.Checkbox(True, label="Preprocess initially-generated image by removing background", interactive=True)
                    with gr.Accordion("Advanced Options", open=False):
                        opt_step = gr.Slider(0, 1000, value=500, step=1, label='Number of text embedding optimization step')
                        pivot_step = gr.Slider(0, 1000, value=500, step=1, label='Number of pivotal tuning step for LoRA')
                    with gr.Row():
                        button_gen_pc = gr.Button("1. Generate Point Cloud", interactive=True, variant='secondary')
                        # Enabled by the point-cloud callback's gr.update(interactive=True).
                        button_gen_3d = gr.Button("2. Generate 3D", interactive=False, variant='primary')
                with gr.Tab("Image to 3D"):
                    image_input = gr.Image(source='upload', type="pil", interactive=True)
                    prompt_input_2 = gr.Textbox(label="Prompt", value="a dog", interactive=True)
                    word_input_2 = gr.Textbox(label="Keyword for initialization (should be a noun included in the prompt)", value="dog", interactive=True)
                    seed_2 = gr.Slider(label="Seed", minimum=0, maximum=2100000000, step=1, randomize=True)
                    preprocess_choice_2 = gr.Checkbox(True, label="Preprocess initially-generated image by removing background", interactive=False)
                    with gr.Accordion("Advanced Options", open=False):
                        opt_step_2 = gr.Slider(0, 1000, value=500, step=1, label='Number of text embedding optimization step')
                        pivot_step_2 = gr.Slider(0, 1000, value=500, step=1, label='Number of pivotal tuning step for LoRA')
                    with gr.Row():
                        button_gen_pc_2 = gr.Button("1. Generate Point Cloud", interactive=True, variant='secondary')
                        button_gen_3d_2 = gr.Button("2. Generate 3D", interactive=False, variant='primary')
                    gr.Markdown("Note: A photo showing the entire object in a front view is recommended. Also, our framework is not designed with the goal of single shot reconstruction, so it may be difficult to reflect the details of the given image.")
                with gr.Row(scale=1.):
                    with gr.Column(scale=1.):
                        pc_plot = gr.Plot(label="Inferred point cloud")
                    with gr.Column(scale=1.):
                        init_output = gr.Image(label='Generated initial image', interactive=False)
            # Right column: streaming outputs of the 3D generation.
            with gr.Column(scale=1., variant='panel'):
                with gr.Row():
                    with gr.Column(scale=1.):
                        intermediate_output = gr.Image(label="Intermediate Output", interactive=False)
                    with gr.Column(scale=1.):
                        logs = gr.Textbox(label="logs", lines=8, max_lines=8, interactive=False)
                with gr.Row():
                    video_result = gr.Video(label="Video result for generated 3D", interactive=False)
                gr.Markdown("Note: Keyword is used for Textual Inversion. Please choose one important noun included in the prompt. This demo may be slower than directly running run_3DFuse.py.")

        # functions
        # Text tab: 4 initial views (front / overhead / side / back).
        button_gen_pc.click(fn=partial(gen_pc_from_prompt, intermediate, 4),
                            inputs=[prompt_input, word_input, semantic_model_choice, preprocess_choice, seed],
                            outputs=[init_output, pc_plot, button_gen_3d])
        # gen_3d ignores its first argument and dispatches a fresh model on
        # every run, so an explicit None placeholder is bound here.
        button_gen_3d.click(fn=partial(gen_3d, None, intermediate),
                            inputs=[prompt_input, word_input, seed, opt_step, pivot_step],
                            outputs=[intermediate_output, logs, video_result])
        button_gen_pc_2.click(fn=partial(gen_pc_from_image, intermediate),
                              inputs=[image_input, prompt_input_2, word_input_2, preprocess_choice_2, seed_2],
                              outputs=[init_output, pc_plot, button_gen_3d_2])
        button_gen_3d_2.click(fn=partial(gen_3d, None, intermediate),
                              inputs=[prompt_input_2, word_input_2, seed_2, opt_step_2, pivot_step_2],
                              outputs=[intermediate_output, logs, video_result])

    # One job at a time: every event shares the same GPU and the same Intermediate.
    demo.queue(concurrency_count=1)
    demo.launch(share=args.share)