import gradio as gr
import imageio
import numpy as np

from demo.img_gen import img_gen
from demo.mesh_recon import mesh_reconstruction
from demo.relighting_gen import relighting_gen
from demo.render_hints import render_hint_images_btn_func
from demo.rm_bg import rm_bg


with gr.Blocks(title="DiLightNet Demo") as demo:
    # Five-step pipeline UI. Each step's output components are wired (via the .click
    # handlers below) as inputs of later steps:
    #   input_image -> rm_bg -> (masked_image, mask) -> mesh_reconstruction -> mesh
    #   -> render_wrapper -> (hint_image, res_folder_path) -> gen_relighting_image -> res_image
    # NOTE: the order in which components are created inside each `with` context
    # determines the on-screen layout, so statements here must not be reordered.
    gr.Markdown("# DiLightNet: Fine-grained Lighting Control for Image Diffusion")

    with gr.Row():
        # 1. Reference Image Input / Generation
        # The user either uploads an image into `input_image` or generates one with img_gen.
        with gr.Column(variant="panel"):
            gr.Markdown("## Step 1. Input or Generate Reference Image")
            input_image = gr.Image(height=512, width=512, label="Input Image", interactive=True)
            with gr.Accordion("Generate Image", open=False):
                with gr.Group():
                    # `prompt` is also reused by the Step 5 "Reuse Image Generation Prompt" button.
                    prompt = gr.Textbox(value="", label="Prompt", lines=3, placeholder="Input prompt here")
                    with gr.Row():
                        seed = gr.Number(value=42, label="Seed", interactive=True)
                        steps = gr.Number(value=20, label="Steps", interactive=True)
                        cfg = gr.Number(value=7.5, label="CFG", interactive=True)
                        down_from_768 = gr.Checkbox(label="Downsample from 768", value=True)
                with gr.Row():
                    generate_btn = gr.Button(value="Generate")
                    # Generated image replaces whatever is currently in `input_image`.
                    generate_btn.click(fn=img_gen, inputs=[prompt, seed, steps, cfg, down_from_768], outputs=[input_image])

        # 2. Background Removal
        # rm_bg fills both the masked image and the mask. `masked_image` is interactive, so the
        # user may also supply it directly (in which case `mask` may still be empty).
        with gr.Column(variant="panel"):
            gr.Markdown("## Step 2. Remove Background")
            with gr.Tab("Masked Image"):
                masked_image = gr.Image(height=512, width=512, label="Masked Image", interactive=True)
            with gr.Tab("Mask"):
                mask = gr.Image(height=512, width=512, label="Mask", interactive=False)
            use_sam = gr.Checkbox(label="Use SAM for Refinement", value=False)
            rm_bg_btn = gr.Button(value="Remove Background")
            rm_bg_btn.click(fn=rm_bg, inputs=[input_image, use_sam], outputs=[masked_image, mask])

        # 3. Depth Estimation & Mesh Reconstruction
        # mesh_reconstruction consumes the Step 2 outputs and writes the `mesh` Model3D component.
        with gr.Column(variant="panel"):
            gr.Markdown("## Step 3. Depth Estimation & Mesh Reconstruction")
            mesh = gr.Model3D(label="Mesh Reconstruction", clear_color=(1.0, 1.0, 1.0, 1.0), interactive=True)
            with gr.Column():
                with gr.Accordion("Options", open=False):
                    with gr.Group():
                        remove_edges = gr.Checkbox(label="Remove Occlusion Edges", value=False)
                        # Also passed to the Step 4 renderer, so reconstruction and hint rendering
                        # share the same camera FOV (presumably in degrees — confirm).
                        fov = gr.Number(value=55., label="FOV", interactive=True)
                        # Slider range 0-255; its exact use is up to mesh_reconstruction.
                        mask_threshold = gr.Slider(value=25., label="Mask Threshold", minimum=0., maximum=255., step=1.)
                depth_estimation_btn = gr.Button(value="Estimate Depth")
                depth_estimation_btn.click(
                    fn=mesh_reconstruction,
                    inputs=[masked_image, mask, remove_edges, fov, mask_threshold],
                    outputs=[mesh]
                )

    # 4. Render Hints: place a single point light and render hint images from the mesh.
    gr.Markdown("## Step 4. Render Hints")
    with gr.Row():
        with gr.Column():
            hint_image = gr.Image(label="Hint Image")
        with gr.Column():
            # Point-light position components, each limited to [-5, 5] by the slider range.
            pl_pos_x = gr.Slider(value=3., label="Point Light X", minimum=-5., maximum=5., step=0.01)
            pl_pos_y = gr.Slider(value=1., label="Point Light Y", minimum=-5., maximum=5., step=0.01)
            pl_pos_z = gr.Slider(value=3., label="Point Light Z", minimum=-5., maximum=5., step=0.01)
            power = gr.Slider(value=1000., label="Point Light Power", minimum=0., maximum=2000., step=1.)
            render_btn = gr.Button(value="Render Hints")
            # Hidden carrier for the hint-render output folder: written by render_wrapper,
            # read by gen_relighting_image in Step 5. Empty until hints have been rendered.
            res_folder_path = gr.Textbox("", visible=False)

            def render_wrapper(mesh, fov, pl_pos_x, pl_pos_y, pl_pos_z, power,
                               progress=gr.Progress(track_tqdm=True)):
                """Render the Step 4 hint images for one point light and show them side by side.

                Args:
                    mesh: Value of the Step 3 ``gr.Model3D`` component; empty/``None`` until a
                        mesh has been reconstructed.
                    fov: Field of view from the Step 3 options, forwarded to the renderer.
                    pl_pos_x, pl_pos_y, pl_pos_z: Point-light position from the sliders.
                    power: Point-light power from the slider.
                    progress: Injected by Gradio; ``track_tqdm=True`` mirrors tqdm progress bars
                        created during the call into the UI. Not referenced directly.

                Returns:
                    ``(hints, res_path)`` where ``hints`` is the ``hint00_diffuse.png`` and
                    ``hint00_ggx0.34.png`` images concatenated along axis 1 (width, for HxWxC
                    arrays) and ``res_path`` is the folder returned by the renderer. The folder is
                    stored in a hidden textbox so Step 5 can find the hints.

                Raises:
                    gr.Error: If no mesh is available yet, so the user sees a readable message
                        instead of an opaque renderer traceback.
                """
                if not mesh:
                    raise gr.Error("Please reconstruct a mesh in Step 3 before rendering hints.")

                res_path = render_hint_images_btn_func(mesh, fov, [(pl_pos_x, pl_pos_y, pl_pos_z)], power)

                # A single light position is passed, so read index 00 (presumably the first/only
                # light — confirm against render_hint_images_btn_func). "ggx0.34" presumably names
                # a GGX specular hint at roughness 0.34 — TODO confirm.
                hint_suffixes = ("_diffuse.png", "_ggx0.34.png")
                hints = [imageio.v3.imread(f"{res_path}/hint00{suffix}") for suffix in hint_suffixes]

                # Side-by-side preview; np.concatenate requires equal heights and channel counts.
                return np.concatenate(hints, axis=1), res_path

            # Render hints from the Step 3 mesh. The shared Step 3 `fov` is reused here, and the
            # output folder goes into the hidden `res_folder_path` for Step 5.
            render_btn.click(
                fn=render_wrapper,
                inputs=[mesh, fov, pl_pos_x, pl_pos_y, pl_pos_z, power],
                outputs=[hint_image, res_folder_path]
            )

    # 5. Relighting: generate the relit image from the masked reference plus the Step 4 hints.
    gr.Markdown("## Step 5. Relighting!")
    with gr.Row():
        res_image = gr.Image(label="Result Image")
        with gr.Column():
            with gr.Group():
                relighting_prompt = gr.Textbox(value="", label="Relighting Text Prompt", lines=3,
                                               placeholder="Input prompt here",
                                               interactive=True)
                reuse_btn = gr.Button(value="Reuse Image Generation Prompt")
                # Identity callback: copies the Step 1 prompt text into the relighting prompt.
                reuse_btn.click(fn=lambda x: x, inputs=[prompt], outputs=[relighting_prompt])
                with gr.Accordion("Options", open=False):
                    with gr.Row():
                        # Relighting has its own seed/steps/CFG defaults, separate from Step 1.
                        relighting_seed = gr.Number(value=3407, label="Seed", interactive=True)
                        relighting_steps = gr.Number(value=20, label="Steps", interactive=True)
                        relighting_cfg = gr.Number(value=3.0, label="CFG", interactive=True)
            with gr.Row():
                relighting_generate_btn = gr.Button(value="Generate")

            def gen_relighting_image(masked_image, mask, res_folder_path, relighting_prompt, relighting_seed,
                                     relighting_steps, relighting_cfg,
                                     progress=gr.Progress(track_tqdm=True)):
                """Generate a single relit image from the masked reference and the rendered hints.

                Args:
                    masked_image: Step 2 masked reference image.
                    mask: Step 2 mask, forwarded unchanged to ``relighting_gen``. It may be empty if
                        the user supplied the masked image directly; handling that case is left to
                        ``relighting_gen``.
                    res_folder_path: Hint folder written by Step 4; empty until hints are rendered.
                    relighting_prompt: Text prompt for the relighting model.
                    relighting_seed, relighting_steps: Cast to ``int`` before use.
                    relighting_cfg: Guidance scale, cast to ``float`` before use.
                    progress: Injected by Gradio; ``track_tqdm=True`` mirrors tqdm progress bars
                        created during the call into the UI. Not referenced directly.

                Returns:
                    The image read from ``<res_folder_path>/relighting00.png``. The function relies
                    on ``relighting_gen`` having written that file; this is not verified here.

                Raises:
                    gr.Error: If the masked image or the hint folder is missing.
                """
                if masked_image is None:
                    raise gr.Error("Please provide a masked image in Step 2 before relighting.")
                # Without this guard an empty path yields cond_path="" and a read of
                # "/relighting00.png" from the filesystem root.
                if not res_folder_path:
                    raise gr.Error("Please render hints in Step 4 before relighting.")

                relighting_gen(
                    masked_ref_img=masked_image,
                    mask=mask,
                    cond_path=res_folder_path,
                    frames=1,
                    prompt=relighting_prompt,
                    # gr.Number values may arrive as floats; normalize the numeric types.
                    steps=int(relighting_steps),
                    seed=int(relighting_seed),
                    cfg=float(relighting_cfg),
                )
                return imageio.v3.imread(f"{res_folder_path}/relighting00.png")


            # Relighting consumes the Step 2 outputs, the hidden Step 4 folder path and the
            # Step 5 prompt/options; the result is shown in `res_image`.
            relighting_generate_btn.click(fn=gen_relighting_image,
                                          inputs=[masked_image, mask, res_folder_path, relighting_prompt, relighting_seed,
                                                  relighting_steps, relighting_cfg],
                                          outputs=[res_image])


if __name__ == '__main__':
    # Turn on Gradio's request queue, which the gr.Progress trackers in the callbacks rely on.
    queued_app = demo.queue()
    # NOTE(review): server_name="0.0.0.0" binds every network interface, and share=True also
    # requests a public Gradio share link, so the demo is reachable from outside this machine.
    queued_app.launch(server_name="0.0.0.0", share=True)