import spaces
import gradio as gr
import os
import shutil
from pathlib import Path
from toon3d.scripts.viser_vis import main as viser_vis_main
import viser
import time
import threading
# --- Module-level state shared between Gradio callbacks and the viewer thread ---

# Handle of the background viewer thread; None when no viewer is tracked.
viewer_thread_instance: "threading.Thread | None" = None
# Setting this event asks the viewer thread to stop its Viser server.
stop_event: threading.Event = threading.Event()
# Share URL published by the viewer thread (None until it has set one).
shared_url: "str | None" = None

# Markdown rendered at the top of the demo page.
_HEADER_ = '''
Toon3D: Seeing Cartoons from a New Perspective
Toon3D lifts hand-drawn images to 3D with a piecewise-rigid deformation optimization at hand-labeled keypoints and using monocular depth as a prior. The project page is at https://toon3d.studio/ and the Toon3D Labeler is at https://labeler.toon3d.studio/. Follow the steps below to run Toon3D!
Here is a video for how to use our demo:
Prepare and Process Data
Upload images and click on "Process Data" to generate processed data.
Download the processed data.
Label Data
Upload the processed data and label points using the labeler ("Upload ZIP").
Click export and upload the points.json to the "Labeled Points" section.
Generate 3D Output
Click on "Run Toon3D" to run the structure from motion pipeline.
Download the output and inspect locally (point cloud, mesh, Nerfstudio dataset).
View in Web!
Click on "Open Viewer" to view the output in an interactive viewer powered by Viser.
Reach out if you have any questions!
'''
def check_input_images(input_images):
    """Reject the "Process Data" click when no images were uploaded.

    Args:
        input_images: Value of the multi-file upload component.

    Raises:
        gr.Error: If nothing was uploaded. This covers both ``None`` and an
            empty list, so an empty upload no longer slips through to the GPU
            job.
    """
    if not input_images:
        raise gr.Error("No images uploaded!")
@spaces.GPU(duration=180)
def process_images(input_images):
    """Copy the uploaded images to a scratch folder and run Toon3D preprocessing.

    Steps:
      1. Reset ``/tmp/gradio/images`` and ``/tmp/gradio/processed``.
      2. Copy every uploaded file into the images folder, keeping its basename.
      3. Fetch the SAM checkpoint with ``tnd-download-data``.
      4. Run ``tnd-process-data initialize`` on the images.
      5. Zip ``/tmp/gradio/processed/toon3d-dataset`` for download.

    Args:
        input_images: Uploaded files; each item exposes its on-disk path as
            ``.name``.

    Returns:
        Path of the created ``toon3d-dataset.zip`` archive.

    Raises:
        gr.Error: If the CLI tools cannot be launched or no processed dataset
            folder was produced.
    """
    import subprocess

    images_path = "/tmp/gradio/images"
    processed_path = "/tmp/gradio/processed"
    download_dir = "/tmp/gradio"

    # Start from empty folders so files from a previous request never leak in.
    for folder in (images_path, processed_path):
        shutil.rmtree(folder, ignore_errors=True)
        os.makedirs(folder, exist_ok=True)

    for fileobj in input_images:
        target = os.path.join(images_path, os.path.basename(fileobj.name))
        shutil.copyfile(fileobj.name, target)

    process_cmd = [
        "tnd-process-data", "initialize",
        "--dataset", "toon3d-dataset",
        "--input_path", images_path,
        "--data_prefix", processed_path,
        "--sam_checkpoint_prefix", f"{download_dir}/sam-checkpoints",
    ]
    try:
        # A failed download is not fatal by itself: this function never deletes
        # the checkpoint folder, so one from an earlier call may still be there.
        # If it is missing, the processing step below fails and is reported.
        subprocess.run(["tnd-download-data", "sam", "--save-dir", download_dir], check=False)
        returncode = subprocess.run(process_cmd, check=False).returncode
    except OSError as err:  # e.g. the tnd-* executables are not on PATH
        raise gr.Error(f"Could not launch Toon3D preprocessing: {err}") from err
    if returncode != 0:
        print(f"tnd-process-data exited with status {returncode}")

    zip_folder = os.path.join(processed_path, "toon3d-dataset")
    if not os.path.isdir(zip_folder):
        raise gr.Error(f"Processing produced no dataset (tnd-process-data exit status {returncode}).")
    shutil.make_archive(zip_folder, "zip", zip_folder)
    return zip_folder + ".zip"
def toggle_labeler_visibility(visible, *, labeler_url="https://labeler.toon3d.studio/", height=800):
    """Return the HTML for the labeler panel: an embedded labeler or nothing.

    Previously both branches returned an empty string, so ticking "Show
    Labeler" had no effect.

    Args:
        visible: State of the "Show Labeler" checkbox.
        labeler_url: Page to embed. Defaults to the Toon3D Labeler URL given in
            the page header.
        height: Iframe height in pixels.

    Returns:
        An ``<iframe>`` snippet when *visible* is truthy, otherwise ``""``.
    """
    import html

    if not visible:
        return ""
    src = html.escape(labeler_url, quote=True)
    return f'<iframe src="{src}" width="100%" height="{int(height)}" style="border: none;"></iframe>'
def check_input_toon3d(processed_data_zip, labeled_data):
    """Reject the "Run Toon3D" click when an input is missing.

    Args:
        processed_data_zip: Value of the "Processed Data" file component.
        labeled_data: Value of the "Labeled Points" file component.

    Raises:
        gr.Error: Naming the first missing input. The processed-data message
            previously said "No images uploaded!", which pointed users at the
            wrong component.
    """
    if processed_data_zip is None:
        raise gr.Error("No processed data uploaded!")
    if labeled_data is None:
        raise gr.Error("No labeled points uploaded!")
@spaces.GPU(duration=180)
def run_toon3d(processed_data_zip, labeled_data):
    """Run the Toon3D reconstruction on a processed dataset plus labeled points.

    The processed-data zip is unpacked to ``/tmp/gradio/inputs/toon3d-dataset``,
    and the labeled points are copied next to it as ``points.json``.
    ``tnd-run`` then writes timestamped run folders under
    ``/tmp/gradio/outputs/toon3d-dataset/run``. The last one in sorted order is
    zipped and returned.

    Args:
        processed_data_zip: Uploaded processed-data zip (path read from ``.name``).
        labeled_data: Uploaded ``points.json`` from the labeler (path in ``.name``).

    Returns:
        Path of the zip archive of the newest run folder.

    Raises:
        gr.Error: If ``tnd-run`` cannot be launched or produced no run folder.
    """
    import subprocess

    data_prefix = "/tmp/gradio/inputs"
    processed_path = f"{data_prefix}/toon3d-dataset"
    output_prefix = "/tmp/gradio/outputs"
    nerfstudio_folder = "/tmp/gradio/nerfstudio"

    # Clear leftovers from a previous request before staging the new inputs.
    for folder in (processed_path, output_prefix, nerfstudio_folder):
        shutil.rmtree(folder, ignore_errors=True)

    shutil.unpack_archive(processed_data_zip.name, processed_path)
    shutil.copyfile(labeled_data.name, f"{processed_path}/points.json")

    toon3d_cmd = [
        "tnd-run",
        "--dataset", "toon3d-dataset",
        "--data_prefix", data_prefix,
        "--output_prefix", output_prefix,
        "--nerfstudio_folder", nerfstudio_folder,
        "--no-view-point-cloud",
    ]
    try:
        returncode = subprocess.run(toon3d_cmd, check=False).returncode
    except OSError as err:  # e.g. the tnd-run executable is not on PATH
        raise gr.Error(f"Could not launch tnd-run: {err}") from err
    if returncode != 0:
        print(f"tnd-run exited with status {returncode}")

    runs_root = Path(output_prefix) / "toon3d-dataset" / "run"
    # Only directories are run outputs; a stray file must never be zipped as
    # the "latest run".
    run_dirs = sorted(p for p in runs_root.iterdir() if p.is_dir()) if runs_root.is_dir() else []
    if not run_dirs:
        raise gr.Error(f"Toon3D produced no output (tnd-run exit status {returncode}).")

    zip_folder = str(run_dirs[-1])
    shutil.make_archive(zip_folder, "zip", zip_folder)
    return zip_folder + ".zip"
def viewer_thread(processed_data_zip, labeled_data, toon3d_output_zip):
    """Background-thread body: stage the inputs, then serve them with Viser.

    Steps:
      1. Unpack the Toon3D output zip into
         ``/tmp/gradio/viewer/toon3d-dataset/run/temp``.
      2. Unpack the processed data into ``/tmp/gradio/inputs/toon3d-dataset``
         (cleared first so files from an earlier dataset cannot linger), and
         copy the labeled points there as ``points.json``.
      3. Start a ``viser.ViserServer``, request a share URL, store it in the
         module-level ``shared_url``, and print it.
      4. Call ``viser_vis_main`` on the staged data, then block until
         ``stop_event`` is set.

    The server is stopped in a ``finally`` clause, so it is also shut down if
    visualization setup raises.
    """
    global shared_url
    data_prefix = Path("/tmp/gradio/inputs")
    processed_path = f"{data_prefix}/toon3d-dataset"
    viewer_folder = "/tmp/gradio/viewer/toon3d-dataset/run/temp"

    shutil.rmtree(viewer_folder, ignore_errors=True)
    shutil.rmtree(processed_path, ignore_errors=True)
    shutil.unpack_archive(toon3d_output_zip.name, viewer_folder)
    shutil.unpack_archive(processed_data_zip.name, processed_path)
    shutil.copyfile(labeled_data.name, f"{processed_path}/points.json")

    viser_server = viser.ViserServer()
    try:
        # NOTE(review): request_share_url() may return None if no share tunnel
        # can be established; confirm against the installed viser version.
        url = viser_server.request_share_url()
        shared_url = url  # picked up by start_viewer(), which polls this global
        print(url)
        viser_vis_main(
            data_prefix=data_prefix,
            dataset="toon3d-dataset",
            output_prefix=Path("/tmp/gradio/viewer"),
            output_method=Path("run"),
            server=viser_server,
            visible=True,
            return_early=True,
        )
        # Wake up as soon as kill_viewer() sets the event, instead of up to 1 s later.
        stop_event.wait()
    finally:
        viser_server.stop()
def kill_viewer():
    """Stop the tracked viewer thread, if one is alive, and wait for it to exit.

    Signals the thread through the module-level ``stop_event``, joins it, and
    forgets the handle. Prints a one-line status message in either case.
    """
    global viewer_thread_instance
    worker = viewer_thread_instance
    if worker is None or not worker.is_alive():
        print("No viewer is running.")
        return
    stop_event.set()
    worker.join()
    viewer_thread_instance = None
    print("Viewer has been stopped.")
def get_html_for_shared_url(url):
    """Render the Viser share URL as an HTML link for the ``viser_link`` panel.

    The original body was an f-string left unterminated across two lines,
    which is a SyntaxError.

    Args:
        url: Share URL of the viewer, or ``None``/``""`` when there is none.

    Returns:
        An ``<a>`` element opening the URL in a new tab (with the URL escaped
        for HTML), or ``""`` when *url* is empty.
    """
    import html

    if not url:
        return ""
    href = html.escape(url, quote=True)
    return f'<a href="{href}" target="_blank" rel="noopener noreferrer">{href}</a>'
def check_input_open_viewer(processed_data_zip, labeled_data, toon3d_output_zip):
    """Make sure every input the viewer needs has been uploaded.

    The inputs are checked in the order processed data, labeled points, Toon3D
    output, and only the first missing one is reported.

    Raises:
        gr.Error: With a message naming the first input that is ``None``.
    """
    required = (
        (processed_data_zip, "No processed data uploaded!"),
        (labeled_data, "No labeled points uploaded!"),
        (toon3d_output_zip, "No Toon3D output uploaded!"),
    )
    for value, message in required:
        if value is None:
            raise gr.Error(message)
def start_viewer(processed_data_zip, labeled_data, toon3d_output_zip, timeout=300.0):
    """(Re)start the background viewer for the given inputs and return its link HTML.

    Any running viewer is stopped first. A new thread running ``viewer_thread``
    is started, and this call waits until that thread publishes
    ``shared_url``.

    The previous implementation spun forever if the thread died, or never set a
    URL, before publishing one.

    Args:
        processed_data_zip: Uploaded processed-data zip.
        labeled_data: Uploaded ``points.json``.
        toon3d_output_zip: Uploaded Toon3D output zip.
        timeout: Maximum seconds to wait for the URL; ``None`` waits
            indefinitely while the thread is alive.

    Returns:
        HTML produced by ``get_html_for_shared_url`` for the new URL.

    Raises:
        gr.Error: If the viewer thread exits without publishing a URL, or if
            *timeout* elapses first.
    """
    global viewer_thread_instance, shared_url
    # kill_viewer() joins any live viewer thread, so none is running after this
    # call and a fresh one can always be started.
    kill_viewer()
    stop_event.clear()
    shared_url = None

    worker = threading.Thread(
        target=viewer_thread,
        args=(processed_data_zip, labeled_data, toon3d_output_zip),
    )
    viewer_thread_instance = worker
    worker.start()

    deadline = None if timeout is None else time.monotonic() + timeout
    while not shared_url:
        if not worker.is_alive():
            raise gr.Error("The viewer failed to start; check the server logs.")
        if deadline is not None and time.monotonic() > deadline:
            stop_event.set()  # let the thread stop its server and exit
            raise gr.Error("Timed out waiting for the viewer URL.")
        time.sleep(0.1)
    return get_html_for_shared_url(shared_url)
with gr.Blocks(title="Toon3D") as demo:
    gr.Markdown(_HEADER_)

    # Step 1: upload images and preprocess them into a labeler-ready dataset.
    with gr.Row(variant="panel"):
        # Extensions must include the leading dot; "jpeg"/"png" without it did
        # not name an extension.
        input_images = gr.File(
            label="Upload Images",
            file_count="multiple",
            file_types=[".jpg", ".jpeg", ".png"],
        )
        process_data_button = gr.Button("Process Data", elem_id="process_data_button", variant="primary")
        processed_data_zip = gr.File(label="Processed Data", file_count="single", file_types=[".zip"], interactive=True)

    # Step 2: optionally embed the labeler.
    with gr.Row(variant="panel"):
        labeler_visible = gr.Checkbox(label="Show Labeler", value=False)
    with gr.Row(variant="panel"):
        labeler_frame = gr.HTML()
    labeler_visible.change(toggle_labeler_visibility, inputs=[labeler_visible], outputs=[labeler_frame])

    # Step 3: run Toon3D on the processed data plus labeled points.
    with gr.Row(variant="panel"):
        labeled_data = gr.File(label="Labeled Points", file_count="single", file_types=[".json"])
        run_toon3d_button = gr.Button("Run Toon3D", elem_id="run_toon3d_button", variant="primary")
        toon3d_output_zip = gr.File(label="Toon3D Output", file_count="single", file_types=[".zip"], interactive=True)

    # Step 4: open the web viewer.
    with gr.Row(variant="panel"):
        open_viewer_button = gr.Button("Open Viewer", elem_id="open_viser_button", variant="primary")
    with gr.Row(variant="panel"):
        viser_link = gr.HTML()

    # Each action first runs a cheap validator; the heavy job only runs when
    # validation succeeds.
    process_data_button.click(fn=check_input_images, inputs=[input_images]).success(
        fn=process_images,
        inputs=[input_images],
        outputs=[processed_data_zip],
    )
    run_toon3d_button.click(fn=check_input_toon3d, inputs=[processed_data_zip, labeled_data]).success(
        fn=run_toon3d,
        inputs=[processed_data_zip, labeled_data],
        outputs=[toon3d_output_zip],
    )
    open_viewer_button.click(
        fn=check_input_open_viewer,
        inputs=[processed_data_zip, labeled_data, toon3d_output_zip],
    ).success(
        fn=start_viewer,
        inputs=[processed_data_zip, labeled_data, toon3d_output_zip],
        outputs=[viser_link],
    )

if __name__ == "__main__":
    demo.queue(max_size=10)
    demo.launch(share=True)