from __future__ import annotations

import functools
import os
import tempfile
from collections.abc import Iterator
from pathlib import Path

import torch
import gradio as gr
from PIL import Image
from gradio_imageslider import ImageSlider
from gradio.utils import get_cache_folder

# Constants
DEFAULT_SHARPNESS = 2


class Examples(gr.helpers.Examples):
    """``gr.helpers.Examples`` that can cache into a named sub-folder.

    The parent is constructed with ``_initiated_directly=False`` and
    ``self.create()`` is then called explicitly. When ``directory_name`` is
    given, the cache location is first redirected to
    ``get_cache_folder() / directory_name``.
    """

    def __init__(self, *args, directory_name=None, **kwargs):
        super().__init__(*args, **kwargs, _initiated_directly=False)
        if directory_name is not None:
            # NOTE(review): only cached_folder / cached_file are redirected;
            # confirm no other cache paths need updating for the installed
            # gradio version.
            self.cached_folder = get_cache_folder() / directory_name
            self.cached_file = Path(self.cached_folder) / "log.csv"
        # Always create: the parent was told it was not initiated directly, so
        # nothing else will call create() when directory_name is None.
        self.create()


def load_predictor():
    """Load model predictor using torch.hub"""
    predictor = torch.hub.load("hugoycj/StableNormal", "StableNormal", trust_repo=True)
    return predictor


def process_image(
    predictor,
    path_input: str,
    sharpness: int = DEFAULT_SHARPNESS,
    data_type: str = "object",
) -> Iterator[list]:
    """Estimate a normal map for a single image.

    This is a generator that yields exactly one ``[input_image, output_path]``
    pair, which is the value shown by the ImageSlider output.

    Args:
        predictor: Callable returned by :func:`load_predictor`.
        path_input: Filesystem path of the uploaded or example image.
        sharpness: Forwarded to the predictor as ``num_inference_steps``.
        data_type: Forwarded to the predictor. This file uses "object" and
            "indoor".

    Raises:
        gr.Error: If no image was provided.
    """
    if not path_input:
        raise gr.Error("Please upload an image or select one from the gallery.")

    name_base = os.path.splitext(os.path.basename(path_input))[0]
    # A fresh temp dir per call keeps concurrent requests from clobbering
    # each other's output files.
    out_path = os.path.join(tempfile.mkdtemp(), f"{name_base}_normal.png")

    # Decode into an independent in-memory copy so the file handle is closed
    # immediately instead of whenever the image object is garbage-collected.
    with Image.open(path_input) as img:
        input_image = img.copy()

    normal_image = predictor(
        input_image,
        # Cast defensively in case the slider delivers a float.
        num_inference_steps=int(sharpness),
        match_input_resolution=False,
        data_type=data_type,
    )
    normal_image.save(out_path)
    yield [input_image, out_path]


def _list_examples(category: str) -> list[str]:
    """Return the sorted example image paths under ``files/<category>``.

    Returns an empty list when that directory does not exist.
    """
    folder = os.path.join("files", category)
    if not os.path.isdir(folder):
        return []
    return sorted(os.path.join(folder, name) for name in os.listdir(folder))


def _validate_input(path_input, _sharpness):
    """Reject a submit with no image before the (queued) model call runs.

    Raising, rather than returning, ``gr.Error`` is what stops the chained
    ``.success()`` handler from firing.
    """
    if not path_input:
        raise gr.Error("Please upload an image")


def _build_tab(tab_label, input_label, process_fn, category):
    """Build one tab: input controls, result slider, cached examples, handlers.

    Must be called inside an active ``gr.Blocks`` / ``gr.Tabs`` context.
    """
    with gr.Tab(tab_label):
        with gr.Row():
            with gr.Column():
                image_input = gr.Image(label=input_label, type="filepath")
                sharpness = gr.Slider(
                    minimum=1,
                    maximum=10,
                    value=DEFAULT_SHARPNESS,
                    step=1,
                    label="Sharpness (inference steps)",
                    info="Higher values produce sharper results but take longer",
                )
                with gr.Row():
                    submit_btn = gr.Button("Compute Normal", variant="primary")
                    reset_btn = gr.Button("Reset")
            with gr.Column():
                output_slider = ImageSlider(
                    label="Normal outputs",
                    type="filepath",
                    show_download_button=True,
                    show_share_button=True,
                    interactive=False,
                    elem_classes="slider",
                    position=0.25,
                )

        example_paths = _list_examples(category)
        # Skip the gallery entirely when there is nothing to show, instead of
        # crashing on a missing directory.
        if example_paths:
            Examples(
                fn=process_fn,
                examples=example_paths,
                inputs=[image_input],
                outputs=[output_slider],
                cache_examples=True,
                directory_name=f"examples_{category}",
                examples_per_page=50,
            )

    submit_btn.click(
        fn=_validate_input,
        inputs=[image_input, sharpness],
        outputs=None,
        queue=False,
    ).success(
        fn=process_fn,
        inputs=[image_input, sharpness],
        outputs=[output_slider],
    )
    reset_btn.click(
        fn=lambda: (None, DEFAULT_SHARPNESS, None),
        inputs=[],
        outputs=[image_input, sharpness, output_slider],
        queue=False,
    )


def create_demo():
    """Load the model and build the Blocks app with Object/Scene/Human tabs."""
    # Load model
    predictor = load_predictor()

    # Create processing functions for each data type
    process_object = functools.partial(process_image, predictor, data_type="object")
    process_scene = functools.partial(process_image, predictor, data_type="indoor")
    # NOTE(review): the Human tab reuses the "object" mode; confirm intended.
    process_human = functools.partial(process_image, predictor, data_type="object")

    # Define markdown content
    HEADER_MD = """
# StableNormal: Reducing Diffusion Variance for Stable and Sharp Normal
"""

    css = """
    .slider .inner { width: 5px; background: #FFF; }
    .viewport { aspect-ratio: 4/3; }
    .tabs button.selected { font-size: 20px !important; color: crimson !important; }
    h1, h2, h3 { text-align: center; display: block; }
    .md_feedback li { margin-bottom: 0px !important; }
    """

    # Create interface
    demo = gr.Blocks(title="Stable Normal Estimation", css=css)
    with demo:
        gr.Markdown(HEADER_MD)
        with gr.Tabs():
            _build_tab("Object", "Input Object Image", process_object, "object")
            _build_tab("Scene", "Input Scene Image", process_scene, "scene")
            _build_tab("Human", "Input Human Image", process_human, "human")

    return demo


def main():
    """Build the demo and serve it on all interfaces, port 7860."""
    demo = create_demo()
    demo.queue(api_open=False).launch(
        server_name="0.0.0.0",
        server_port=7860,
    )


if __name__ == "__main__":
    main()