from typing import Optional

import gradio as gr
import numpy as np
import spaces
import supervision as sv
import torch
from PIL import Image
from gradio_image_prompter import ImagePrompter

from utils.models import load_models, CHECKPOINT_NAMES, MODE_NAMES, \
    MASK_GENERATION_MODE, BOX_PROMPT_MODE

MARKDOWN = """
# Segment Anything Model 2 🔥
<div>
    <a href="https://github.com/facebookresearch/segment-anything-2">
        <img src="https://badges.aleen42.com/src/github.svg" alt="GitHub" style="display:inline-block;">
    </a>
    <a href="https://colab.research.google.com/github/roboflow-ai/notebooks/blob/main/notebooks/how-to-segment-images-with-sam-2.ipynb">
        <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Colab" style="display:inline-block;">
    </a>
    <a href="https://blog.roboflow.com/what-is-segment-anything-2/">
        <img src="https://raw.githubusercontent.com/roboflow-ai/notebooks/main/assets/badges/roboflow-blogpost.svg" alt="Roboflow" style="display:inline-block;">
    </a>
    <a href="https://www.youtube.com/watch?v=Dv003fTyO-Y">
        <img src="https://badges.aleen42.com/src/youtube.svg" alt="YouTube" style="display:inline-block;">
    </a>
</div>

Segment Anything Model 2 (SAM 2) is a foundation model designed to address promptable 
visual segmentation in both images and videos. **Video segmentation will be available 
soon.**
"""
EXAMPLES = [
    ["tiny", MASK_GENERATION_MODE, "https://media.roboflow.com/notebooks/examples/dog-2.jpeg", None],
    ["tiny", MASK_GENERATION_MODE, "https://media.roboflow.com/notebooks/examples/dog-3.jpeg", None],
    ["tiny", MASK_GENERATION_MODE, "https://media.roboflow.com/notebooks/examples/dog-4.jpeg", None],
]

# The app targets Spaces ZeroGPU (see the @spaces.GPU decorator below), so a
# CUDA device is assumed throughout.
DEVICE = torch.device('cuda')

# Enable TF32 matmul/conv kernels on Ampere (compute capability >= 8) GPUs.
# bfloat16 autocast is applied per request by the decorator on `process`;
# autocast state is thread-local, so entering a module-level autocast context
# would not reach Gradio's worker threads anyway.
if torch.cuda.is_available() and torch.cuda.get_device_properties(0).major >= 8:
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

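# One annotator shared by both modes; ColorLookup.INDEX gives each detected
# mask its own color.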
MASK_ANNOTATOR = sv.MaskAnnotator(color_lookup=sv.ColorLookup.INDEX)
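# load_models is expected to return two dicts keyed by checkpoint name:
# {name: SAM2 image predictor} and {name: SAM2 automatic mask generator}.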
IMAGE_PREDICTORS, MASK_GENERATORS = load_models(device=DEVICE)


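# @spaces.GPU requests a ZeroGPU slot for the duration of each call;
# inference_mode and bfloat16 autocast keep inference fast and memory-light.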
@spaces.GPU
@torch.inference_mode()
@torch.autocast(device_type="cuda", dtype=torch.bfloat16)
def process(
    checkpoint_dropdown,
    mode_dropdown,
    image_input,
    image_prompter_input
) -> Optional[Image.Image]:
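    """Run SAM2 on the uploaded image.

    In `box prompt` mode the image and user-drawn boxes come from the
    ImagePrompter component; in `mask generation` mode the plain image input
    is segmented automatically. Returns the annotated image (or the raw input
    when no boxes were drawn), and None when no mode matched.
    """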
    if mode_dropdown == BOX_PROMPT_MODE:
        if image_prompter_input is None:
            return None
        image_input = image_prompter_input["image"]
        prompt = image_prompter_input["points"]
        if not prompt:
            return image_input

        model = IMAGE_PREDICTORS[checkpoint_dropdown]
        image = np.array(image_input.convert("RGB"))
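        # Each prompt entry unpacks as (x1, y1, _, x2, y2, _): the two corner
        # points of a drawn box, with ImagePrompter's click-type flags ignored.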
        box = np.array([[x1, y1, x2, y2] for x1, y1, _, x2, y2, _ in prompt])

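        # Embed the image once, then prompt with all boxes in a single call.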
        model.set_image(image)
        masks, _, _ = model.predict(box=box, multimask_output=False)

        # predict() returns masks of shape (N, 1, H, W) when boxes are batched;
        # squeeze only the singleton channel axis, so a single box (1, 1, H, W)
        # still yields a 3D (N, H, W) array, as sv.Detections expects. A bare
        # np.squeeze would collapse it to 2D and break the annotation below.
        if len(masks.shape) == 4:
            masks = np.squeeze(masks, axis=1)

        detections = sv.Detections(
            xyxy=sv.mask_to_xyxy(masks=masks),
            mask=masks.astype(bool)
        )
        return MASK_ANNOTATOR.annotate(image_input, detections)

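    # Automatic mode: SAM2's mask generator segments the whole image with no
    # prompts.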
    if mode_dropdown == MASK_GENERATION_MODE:
        model = MASK_GENERATORS[checkpoint_dropdown]
        image = np.array(image_input.convert("RGB"))
        result = model.generate(image)
        detections = sv.Detections.from_sam(result)
        return MASK_ANNOTATOR.annotate(image_input, detections)


with gr.Blocks() as demo:
    gr.Markdown(MARKDOWN)
    with gr.Row():
        checkpoint_dropdown_component = gr.Dropdown(
            choices=CHECKPOINT_NAMES,
            value=CHECKPOINT_NAMES[0],
            label="Checkpoint", info="Select a SAM2 checkpoint to use.",
            interactive=True
        )
        mode_dropdown_component = gr.Dropdown(
            choices=MODE_NAMES,
            value=MODE_NAMES[0],
            label="Mode",
            info="Select a mode. Use `box prompt` to generate masks for "
                 "objects you select with drawn boxes, or `mask generation` "
                 "to generate masks for the whole image.",
            interactive=True
        )
    with gr.Row():
        with gr.Column():
            image_input_component = gr.Image(
                type='pil', label='Upload image')
            image_prompter_input_component = ImagePrompter(
                type='pil', label='Image prompt', visible=False)
            submit_button_component = gr.Button(
                value='Submit', variant='primary')
        with gr.Column():
            image_output_component = gr.Image(type='pil', label='Image output')
    with gr.Row():
        gr.Examples(
            fn=process,
            examples=EXAMPLES,
            inputs=[
                checkpoint_dropdown_component,
                mode_dropdown_component,
                image_input_component,
                image_prompter_input_component,
            ],
            outputs=[image_output_component],
            cache_examples=False,
            run_on_click=True
        )


    def on_mode_dropdown_change(mode):
        # Toggle between the plain image input (mask-generation mode) and the
        # ImagePrompter, which captures user-drawn boxes (box-prompt mode).
        return [
            gr.Image(visible=mode == MASK_GENERATION_MODE),
            ImagePrompter(visible=mode == BOX_PROMPT_MODE)
        ]

    mode_dropdown_component.change(
        on_mode_dropdown_change,
        inputs=[mode_dropdown_component],
        outputs=[
            image_input_component,
            image_prompter_input_component
        ]
    )
    submit_button_component.click(
        fn=process,
        inputs=[
            checkpoint_dropdown_component,
            mode_dropdown_component,
            image_input_component,
            image_prompter_input_component,
        ],
        outputs=[image_output_component]
    )

demo.launch(debug=False, show_error=True)