from __future__ import annotations

import functools
import os
import tempfile
from collections.abc import Iterator
from pathlib import Path

import torch
import spaces
import gradio as gr
from PIL import Image
from gradio_imageslider import ImageSlider
from gradio.utils import get_cache_folder

class Examples(gr.helpers.Examples):
    """``gr.helpers.Examples`` variant whose cache can live in a named sub-folder.

    When ``directory_name`` is given, ``cached_folder`` becomes
    ``get_cache_folder() / directory_name`` and ``cached_file`` becomes a
    ``log.csv`` inside that folder. Otherwise the parent's values are left
    untouched. All other arguments are forwarded to the parent unchanged.
    """

    def __init__(self, *args, directory_name=None, **kwargs):
        # _initiated_directly=False is handed to the parent. create() is then
        # invoked explicitly at the end, after the cache paths are settled.
        super().__init__(*args, _initiated_directly=False, **kwargs)
        if directory_name is not None:
            folder = get_cache_folder() / directory_name
            self.cached_folder = folder
            self.cached_file = Path(folder) / "log.csv"
        self.create()

def load_predictor():
    """Return the ``StableNormal_turbo`` predictor from the hugoycj/StableNormal hub repo.

    The ``yoso_version`` keyword is forwarded to the hub entry point to select
    the checkpoint.

    Note: ``trust_repo=True`` means the remote repository's hub code is
    executed without an interactive confirmation.
    """
    return torch.hub.load(
        "hugoycj/StableNormal",
        "StableNormal_turbo",
        trust_repo=True,
        yoso_version='yoso-normal-v1-8-1',
    )

def process_image(
    predictor,
    path_input: str | None,
    data_type: str = "object"
) -> Iterator[list]:
    """Estimate a normal map for a single image file.

    This is a generator that yields exactly once, so Gradio can consume it like
    any other generator handler.

    Args:
        predictor: Object returned by ``load_predictor``. It is called as
            ``predictor(image, match_input_resolution=False, data_type=...)``
            and its result must provide ``.save(path)``.
        path_input: Path of the input image on disk, or ``None`` when the user
            supplied nothing.
        data_type: Forwarded unchanged to ``predictor``.

    Yields:
        One ``[input_image, out_path]`` list. It holds the loaded input PIL
        image and the path of the PNG the normal map was written to.

    Raises:
        gr.Error: If ``path_input`` is ``None`` (raised on the first ``next()``).
    """
    if path_input is None:
        raise gr.Error("Please upload an image or select one from the gallery.")

    name_base = os.path.splitext(os.path.basename(path_input))[0]
    # A fresh directory per call keeps same-named uploads from overwriting each
    # other. The directories are not cleaned up here.
    out_path = os.path.join(tempfile.mkdtemp(), f"{name_base}_normal.png")

    # Decode into an independent in-memory copy and close the file handle
    # deterministically. Pillow may otherwise keep it open, e.g. for
    # multi-frame formats.
    with Image.open(path_input) as source:
        input_image = source.copy()

    normal_image = predictor(input_image, match_input_resolution=False, data_type=data_type)
    normal_image.save(out_path)

    yield [input_image, out_path]

def _list_example_files(directory: str) -> list[str]:
    """Return sorted paths of every entry in *directory*, or ``[]`` if it is missing.

    Checking before ``os.listdir`` lets a missing examples folder degrade to
    "no examples" instead of aborting app start-up.
    """
    if not os.path.isdir(directory):
        return []
    return sorted(os.path.join(directory, name) for name in os.listdir(directory))

def _require_image(image_path):
    """Raise a user-facing ``gr.Error`` when no input image was provided.

    Takes exactly one argument because the event wires a single input
    component. The error must be *raised* (not returned) for Gradio to report
    it and skip the chained ``.success`` step.
    """
    if not image_path:
        raise gr.Error("Please upload an image")

def create_demo():
    """Build the Gradio Blocks UI for StableNormal Turbo normal estimation on objects.

    The predictor is loaded eagerly through ``load_predictor`` (``torch.hub``),
    so building the demo may be slow and may need network access.

    Returns:
        The configured, not-yet-launched ``gr.Blocks`` instance.
    """
    # Load model
    predictor = load_predictor()

    # Bind the predictor and data type once. spaces.GPU presumably requests a
    # GPU per call on Hugging Face ZeroGPU hardware (confirm against spaces docs).
    process_object = spaces.GPU(functools.partial(process_image, predictor, data_type="object"))

    # Define markdown content
    HEADER_MD = """
    # 🎪 StableNormal Turbo

    <p align="center">
    <a title="Website" href="https://stable-x.github.io/StableNormal/" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
        <img src="https://www.obukhov.ai/img/badges/badge-website.svg">
    </a>
    <a title="arXiv" href="https://arxiv.org/abs/2406.16864" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
        <img src="https://www.obukhov.ai/img/badges/badge-pdf.svg">
    </a>
    <a title="Github" href="https://github.com/Stable-X/StableNormal" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
        <img src="https://img.shields.io/github/stars/Stable-X/StableNormal?label=GitHub%20%E2%98%85&logo=github&color=C8C" alt="badge-github-stars">
    </a>
    <a title="Social" href="https://x.com/ychngji6" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
        <img src="https://www.obukhov.ai/img/badges/badge-social.svg" alt="social">
    </a>
    </p>
    """

    # Create interface
    demo = gr.Blocks(
        title="Stable Normal Estimation",
        css="""
            .slider .inner { width: 5px; background: #FFF; }
            .viewport { aspect-ratio: 4/3; }
            .tabs button.selected { font-size: 20px !important; color: crimson !important; }
            h1, h2, h3 { text-align: center; display: block; }
            .md_feedback li { margin-bottom: 0px !important; }
        """
    )

    with demo:
        gr.Markdown(HEADER_MD)

        with gr.Tabs():
            # Object Tab
            with gr.Tab("Object"):
                with gr.Row():
                    with gr.Column():
                        object_input = gr.Image(label="Input Object Image", type="filepath")
                        with gr.Row():
                            object_submit_btn = gr.Button("Compute Normal", variant="primary")
                            object_reset_btn = gr.Button("Reset")
                    with gr.Column():
                        object_output_slider = ImageSlider(
                            label="Normal outputs",
                            type="filepath",
                            show_download_button=True,
                            show_share_button=True,
                            interactive=False,
                            elem_classes="slider",
                            position=0.25,
                        )

                # Only render the gallery when example files are actually present.
                example_paths = _list_example_files(os.path.join("files", "object"))
                if example_paths:
                    Examples(
                        fn=process_object,
                        examples=example_paths,
                        inputs=[object_input],
                        outputs=[object_output_slider],
                        cache_examples=False,
                        directory_name="examples_object",
                        examples_per_page=50,
                    )

        # Event Handlers for Object Tab: a cheap un-queued check first, then the
        # GPU job only if the check did not raise.
        object_submit_btn.click(
            fn=_require_image,
            inputs=object_input,
            outputs=None,
            queue=False,
        ).success(
            fn=process_object,
            inputs=object_input,
            outputs=[object_output_slider],
        )

        # One return value per output component: clear the input and the slider.
        object_reset_btn.click(
            fn=lambda: (None, None),
            inputs=[],
            outputs=[object_input, object_output_slider],
            queue=False,
        )

    return demo

def main():
    """Build the demo and serve it on 0.0.0.0:7860, queued with ``api_open=False``."""
    app = create_demo()
    app = app.queue(api_open=False)
    app.launch(server_name="0.0.0.0", server_port=7860)

if __name__ == "__main__":
    main()