import os
import time
import shutil
import subprocess
from pathlib import Path
from typing import Union
import atexit
import spaces
from concurrent.futures import ThreadPoolExecutor
import trimesh

import gradio as gr
from gradio_imageslider import ImageSlider
import cv2
import numpy as np
import imageio
from promptda.promptda import PromptDA
from promptda.utils.io_wrapper import load_image, load_depth
from promptda.utils.depth_utils import visualize_depth, unproject_depth
# Hugging Face Spaces supplies the GPU via the @spaces.GPU decorator below,
# so the model targets CUDA directly; for local runs you could fall back to
# 'mps' or 'cpu' instead.
DEVICE = 'cuda'
model = PromptDA.from_pretrained('depth-anything/promptda_vitl').to(DEVICE).eval()

# Single worker so deferred deletions run sequentially in the background.
thread_pool_executor = ThreadPoolExecutor(max_workers=1)

def delete_later(path: Union[str, os.PathLike], delay: int = 300):
    """Schedule `path` for deletion after `delay` seconds (and at interpreter exit)."""
    print(f"Scheduling deletion of: {path}")

    def _delete():
        try:
            if os.path.isfile(path):
                os.remove(path)
                print(f"Deleted file: {path}")
            elif os.path.isdir(path):
                shutil.rmtree(path)
                print(f"Deleted directory: {path}")
        except OSError:
            pass

    def _wait_and_delete():
        time.sleep(delay)
        _delete()

    thread_pool_executor.submit(_wait_and_delete)
    atexit.register(_delete)


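# spaces.GPU requests a GPU from Hugging Face Spaces (ZeroGPU) for the
# duration of the call, which is why tensors are moved to CUDA inside it.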
@spaces.GPU
def run_with_gpu(image, prompt_depth):
    image = image.to(DEVICE)
    prompt_depth = prompt_depth.to(DEVICE)
    depth = model.predict(image, prompt_depth)
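    # predict returns a (1, 1, H, W) tensor; squeeze it to an HxW numpy array.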
    depth = depth[0, 0].detach().cpu().numpy()
    return depth

def check_is_stray_scanner_app_capture(input_dir):
    # A Stray Scanner capture always contains an rgb.mp4 video.
    assert os.path.exists(os.path.join(input_dir, 'rgb.mp4')), 'rgb.mp4 not found'

def run(input_file, resolution):
    # NOTE: `resolution` comes from the UI dropdown but is not used below yet.
    # Unzip the uploaded capture next to the uploaded file.
    input_file = input_file.name
    root_dir = os.path.dirname(input_file)
    scene_name = Path(input_file).stem
    input_dir = os.path.join(root_dir, scene_name)
    subprocess.run(['unzip', '-o', input_file, '-d', root_dir], check=True)
    check_is_stray_scanner_app_capture(input_dir)

    # Extract the first 10 RGB frames from the capture video.
    os.makedirs(os.path.join(input_dir, 'rgb'), exist_ok=True)
    subprocess.run([
        'ffmpeg', '-i', os.path.join(input_dir, 'rgb.mp4'),
        '-start_number', '0', '-frames:v', '10', '-q:v', '2',
        os.path.join(input_dir, 'rgb', '%06d.jpg'),
    ], check=True)

    # Loading & Inference
    image_path = os.path.join(input_dir, 'rgb', '000000.jpg')
    image = load_image(image_path)
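    # The prompt depth is the low-resolution LiDAR depth frame saved by Stray Scanner.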
    prompt_depth_path = os.path.join(input_dir, 'depth/000000.png')
    prompt_depth = load_depth(prompt_depth_path)
    depth = run_with_gpu(image, prompt_depth)

    # Convert the normalized CHW tensor back to an HWC uint8 image.
    color = (image[0].permute(1, 2, 0).cpu().numpy() * 255.).astype(np.uint8)

    # Visualization: color-map both depths with a shared min/max so they are comparable.
    vis_depth, depth_min, depth_max = visualize_depth(depth, ret_minmax=True)
    vis_prompt_depth = visualize_depth(prompt_depth[0, 0].detach().cpu().numpy(), depth_min=depth_min, depth_max=depth_max)
    vis_prompt_depth = cv2.resize(vis_prompt_depth, (vis_depth.shape[1], vis_depth.shape[0]), interpolation=cv2.INTER_NEAREST)
    # Draw a dark label box with 'Prompt depth' in the bottom-right corner.
    text_x = vis_prompt_depth.shape[1] - 250 + 15
    text_y = vis_prompt_depth.shape[0] - 45 + 27
    vis_prompt_depth = cv2.rectangle(vis_prompt_depth, 
                                     (vis_prompt_depth.shape[1] - 250, vis_prompt_depth.shape[0] - 45), 
                                     (vis_prompt_depth.shape[1] - 5, vis_prompt_depth.shape[0] - 5), 
                                     (70, 70, 70), -1)
    vis_prompt_depth = cv2.putText(vis_prompt_depth, 'Prompt depth', 
                                   (text_x, text_y), 
                                   cv2.FONT_HERSHEY_SIMPLEX, 
                                   1, (255, 255, 255), 2, cv2.LINE_AA)

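    # Matching 'Output depth' label in the bottom-left corner.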
    text_x = 5 + 15
    text_y = vis_depth.shape[0] - 45 + 27
    vis_depth = cv2.rectangle(vis_depth, 
                              (5, vis_depth.shape[0] - 45), 
                              (250, vis_depth.shape[0] - 5), 
                              (70, 70, 70), -1)
    vis_depth = cv2.putText(vis_depth, 'Output depth', 
                            (text_x, text_y), 
                            cv2.FONT_HERSHEY_SIMPLEX, 
                            1, (255, 255, 255), 2, cv2.LINE_AA)

    # Point cloud (.ply): rescale the intrinsics, which Stray Scanner stores
    # for the original capture (long side 1920 px), to the extracted frame size.
    ixt_path = os.path.join(input_dir, 'camera_matrix.csv')
    ixt = np.loadtxt(ixt_path, delimiter=',')
    orig_max = 1920
    now_max = max(color.shape[1], color.shape[0])
    scale = orig_max / now_max
    ixt[:2] = ixt[:2] / scale
    points, colors = unproject_depth(depth, ixt=ixt, color=color, ret_pcd=False)
    pcd = trimesh.PointCloud(vertices=points, colors=colors)
    ply_path = os.path.join(input_dir, 'pointcloud.ply')
    pcd.export(ply_path)

    # Viewer copy (.glb): flip Y/Z so the OpenCV-style camera frame matches
    # glTF's Y-up convention, and append an alpha channel to the colors.
    glb_path = os.path.join(input_dir, 'pointcloud.glb')
    scene_3d = trimesh.Scene()
    glb_colors = np.asarray(colors).astype(np.float32)
    glb_colors = np.concatenate([glb_colors, np.ones_like(glb_colors[:, :1])], axis=1)
    pcd_data = trimesh.PointCloud(
        vertices=np.asarray(points) * np.array([[1, -1, -1]]),
        colors=glb_colors.astype(np.float64),
    )
    scene_3d.add_geometry(pcd_data)
    scene_3d.export(file_obj=glb_path)

    # Save the raw depth as a uint16 PNG in millimeters.
    depth_path = os.path.join(input_dir, 'depth.png')
    output_depth = (depth * 1000).astype(np.uint16)
    imageio.imwrite(depth_path, output_depth)

    # Schedule cleanup of the extracted capture and the uploaded zip.
    delete_later(Path(input_dir))
    delete_later(Path(input_file))

    return color, (vis_depth, vis_prompt_depth), Path(glb_path), Path(ply_path).as_posix(), Path(depth_path).as_posix()

DESCRIPTION = """
# Estimate accurate and high-resolution depth maps from your iPhone capture.

Project Page: [Prompt Depth Anything](https://promptda.github.io/)

## Requirements:
1. An iPhone 12 Pro or later Pro model, or an iPad Pro (2020 or later).
2. Free iOS App: [Stray Scanner App](https://apps.apple.com/us/app/stray-scanner/id1557051662).

## Testing Steps:
1. Capture a scene with the Stray Scanner App. Use the iPhone [Files App](https://apps.apple.com/us/app/files/id1232058109) to compress the capture into a zip file and transfer it to your computer ([example screen recording](https://haotongl.github.io/promptda/assets/ScreenRecording_12-16-2024.mp4)).
2. Upload the zip file and click "Submit" to get the depth map of the first frame.

Notes:
- Currently, this demo only runs inference on the first frame. To get depth for every frame, please refer to our [GitHub repo](https://github.com/DepthAnything/PromptDA).
- The depth map is stored as a uint16 PNG, in millimeters (see the loading snippet below).
- **You can refer to the bottom of this page for an example demo.**
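
For example, to read a downloaded depth map back into meters (a minimal sketch;
`depth.png` stands in for whatever filename you saved the download under):

```python
import imageio.v2 as imageio

depth_mm = imageio.imread('depth.png')         # uint16, millimeters
depth_m = depth_mm.astype('float32') / 1000.0  # float32, meters
```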
"""

def main():
    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        gr.Markdown(DESCRIPTION)

        with gr.Row():
            input_file = gr.File(type="filepath", label="Stray Scanner App capture zip file")
            resolution = gr.Dropdown(choices=['756x1008', '1428x1904'], value='756x1008', label="Inference resolution")
            submit_btn = gr.Button("Submit")
        

        with gr.Row():
            with gr.Column():
                output_rgb = gr.Image(type="numpy", label="RGB Image")
            with gr.Column():
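                # ImageSlider shows the two depth images with a draggable comparison divider.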
                output_depths = ImageSlider(label="Output depth / prompt depth", position=0.5)
        
        with gr.Row():
            with gr.Column():
                output_3d_model = gr.Model3D(label="3D Viewer", display_mode='solid', clear_color=[1.0, 1.0, 1.0, 1.0])
            with gr.Column():
                output_ply = gr.File(type="filepath", label="Download the unprojected point cloud as a .ply file", height=30)
                output_depth_map = gr.File(type="filepath", label="Download the depth map as a .png file", height=30)
        outputs = [
            output_rgb,
            output_depths,
            output_3d_model,
            output_ply,
            output_depth_map,
        ]
        gr.Examples(examples=[
                ["data/assets/example0_chair.zip", "756x1008"]
            ],
            fn=run,
            inputs=[input_file, resolution],
            outputs=outputs,
            label="Examples",
            cache_examples=True,
        ) 
        submit_btn.click(run, 
                         inputs=[input_file, resolution], 
                         outputs=outputs)

    demo.launch(share=True)

if __name__ == '__main__':
    main()