"""Interface for labeling concepts in images.
"""

from typing import Optional

import gradio as gr

from src import global_variables
from src.constants import CONCEPTS, ASSETS_FOLDER, DATASET_NAME

def get_image(
    step: int,
    split: str,
    index: str,
    filtered_indices: dict,
    profile: gr.OAuthProfile
):
    """Load the sample ``step`` positions away from ``index`` in ``split``.

    The requested position is clamped to the valid range of the split. A
    Gradio warning is shown when the index cannot be parsed or when the clamp
    kicks in.

    Args:
        step: Offset added to the parsed index (+1 next, -1 previous,
            0 reload the typed index).
        split: Key into ``global_variables.all_metadata``.
        index: Current index as typed in the textbox; falls back to 0 when it
            is not an integer.
        filtered_indices: Per-session state, returned unchanged.
        profile: OAuth profile of the logged-in user; its ``username``
            selects whose votes are displayed.

    Returns:
        A 9-tuple in the interface's output order: image path, concepts the
        user voted for, ``"split:idx"`` identifier, class label, concept
        annotations, new index string, concepts missing from the user's
        votes, concepts whose annotation is ``None``, and
        ``filtered_indices``. If the split has no samples, a warning is shown
        and placeholder values are returned with ``index`` left unchanged.
    """
    username = profile.username
    samples = global_variables.all_metadata[split]
    if not samples:
        # Without this guard the clamp below yields -1 and indexing raises.
        gr.Warning("No image in this split.")
        return None, [], None, None, None, index, [], [], filtered_indices

    try:
        int_index = int(index)
    except (TypeError, ValueError):
        # Only parsing failures fall back to 0; a bare ``except`` would also
        # swallow KeyboardInterrupt / SystemExit.
        gr.Warning("Error parsing index using 0")
        int_index = 0

    # Clamp to [0, len(samples) - 1] so navigation stops at either end.
    sample_idx = int_index + step
    if sample_idx < 0:
        gr.Warning("No previous image.")
        sample_idx = 0
    if sample_idx >= len(samples):
        gr.Warning("No next image.")
        sample_idx = len(samples) - 1

    sample = samples[sample_idx]
    image_path = f"{ASSETS_FOLDER}/{DATASET_NAME}/data/{split}/{sample['file_name']}"

    try:
        username_votes = global_variables.all_votes[sample["id"]][username]
    except KeyError:
        # No votes recorded for this sample and/or this user yet.
        voted_concepts = []
        unseen_concepts = []
    else:
        # ``username_votes`` presumably maps concept -> bool (TODO confirm);
        # concepts absent from it have never been voted on by this user.
        voted_concepts = [c for c in CONCEPTS if username_votes.get(c, False)]
        unseen_concepts = [c for c in CONCEPTS if c not in username_votes]
    # A ``None`` annotation is surfaced in the "Tie Concepts" group.
    tie_concepts = [c for c in CONCEPTS if sample[c] is None]

    return (
        image_path,
        voted_concepts,
        f"{split}:{sample_idx}",
        sample["class"],
        {c: sample[c] for c in CONCEPTS},
        str(sample_idx),
        unseen_concepts,
        tie_concepts,
        filtered_indices,
    )

def make_get_image(step):
    """Return a Gradio event handler that navigates ``step`` samples away.

    The handler keeps the explicit ``gr.OAuthProfile`` annotation on its last
    parameter, which is how Gradio knows to inject the logged-in user's
    profile when the event fires.
    """

    def f(
        split: str,
        index: str,
        filtered_indices: dict,
        profile: gr.OAuthProfile
    ):
        # Forward everything to the shared loader with the captured offset.
        return get_image(
            step=step,
            split=split,
            index=index,
            filtered_indices=filtered_indices,
            profile=profile,
        )

    return f

# Navigation handlers: +1 advances, -1 goes back, 0 reloads the typed index.
get_next_image, get_prev_image, get_current_image = (
    make_get_image(offset) for offset in (1, -1, 0)
)

def submit_label(
    voted_concepts: list,
    current_image: Optional[str],
    split,
    index,
    filtered_indices,
    profile: gr.OAuthProfile
):
    """Store the user's concept votes for the displayed image, then advance.

    Args:
        voted_concepts: Concepts ticked in the "Voted Concepts" group.
        current_image: ``"split:idx"`` identifier of the displayed sample, or
            ``None`` if no image has been loaded yet.
        split: Currently selected split.
        index: Current value of the index textbox.
        filtered_indices: Per-session state, passed through unchanged.
        profile: OAuth profile of the logged-in user.

    Returns:
        The same 9-tuple as ``get_next_image`` for the following sample.
        When no image is selected, a warning is shown and placeholders are
        returned, keeping the index textbox and ``filtered_indices`` as-is.
    """
    username = profile.username
    if current_image is None:
        gr.Warning("No image selected.")
        # Output order: image, voted, current_image, class, concepts, index,
        # unseen, tie, filtered_indices. ``index`` must sit in the 6th slot so
        # the textbox keeps its value and no string reaches a CheckboxGroup.
        return None, None, None, None, None, index, None, None, filtered_indices

    global_variables.update_votes(username, current_image, voted_concepts)

    gr.Info("Submit success")
    # Move straight on to the next sample so labelling can continue.
    return get_next_image(
        split,
        index,
        filtered_indices,
        profile
    )
    
def save_current_work(
    profile: gr.OAuthProfile,
):
    """Persist the logged-in user's pending work and notify them.

    Delegates to ``global_variables.save_current_work`` with the user's
    username, then shows a success toast.
    """
    global_variables.save_current_work(profile.username)
    gr.Info("Save success")

with gr.Blocks() as interface:
    # Layout: controls and image metadata on the left, the image on the right.
    with gr.Row():
        with gr.Column():
            with gr.Group():
                gr.Markdown(
                    "## Image Selection",
                )
                split = gr.Radio(
                    label="Split",
                    choices=["train", "validation", "test"],
                    value="train",
                )
                index = gr.Textbox(
                    value="0",
                    label="Index",
                    max_lines=1,
                )
            with gr.Group():
                voted_concepts = gr.CheckboxGroup(
                    label="Voted Concepts",
                    choices=CONCEPTS,
                )
                unseen_concepts = gr.CheckboxGroup(
                    label="Previously Unseen Concepts",
                    choices=CONCEPTS,
                )
                tie_concepts = gr.CheckboxGroup(
                    label="Tie Concepts",
                    choices=CONCEPTS,
                )

            with gr.Row():
                prev_button = gr.Button(
                    value="Prev",
                )
                next_button = gr.Button(
                    value="Next",
                )
                gr.LoginButton()
                submit_button = gr.Button(
                    value="Local Submit",
                )
            with gr.Row():
                save_button = gr.Button(
                    value="Save",
                )
            with gr.Group():
                gr.Markdown(
                    "## Image Info",
                )
                im_class = gr.Textbox(
                    label="Class",
                )
                im_concepts = gr.JSON(
                    label="Concepts",
                )
        with gr.Column():
            image = gr.Image(
                label="Image",
            )
    # "split:idx" identifier of the displayed sample; None until one is loaded.
    current_image = gr.State(None)
    # Per-session mapping split -> sample positions (currently every position).
    # The loop variable is not named ``split`` to avoid confusion with the
    # Radio component above.
    filtered_indices = gr.State({
        split_name: list(range(len(global_variables.all_metadata[split_name])))
        for split_name in global_variables.all_metadata
    })
    # Must match the order of the tuple returned by the image handlers.
    common_output = [
        image,
        voted_concepts,
        current_image,
        im_class,
        im_concepts,
        index,
        unseen_concepts,
        tie_concepts,
        filtered_indices,
    ]
    common_input = [split, index, filtered_indices]
    prev_button.click(
        get_prev_image,
        inputs=common_input,
        outputs=common_output
    )
    next_button.click(
        get_next_image,
        inputs=common_input,
        outputs=common_output
    )
    submit_button.click(
        submit_label,
        inputs=[voted_concepts, current_image, split, index, filtered_indices],
        outputs=common_output
    )
    # Pressing Enter in the index box jumps to the typed index.
    index.submit(
        get_current_image,
        inputs=common_input,
        outputs=common_output,
    )

    # ``save_current_work`` returns nothing. Binding it to ``image`` would make
    # Gradio write ``None`` into the image component and blank the preview.
    save_button.click(
        save_current_work,
    )