import collections
import json
import logging
import os
import threading
import uuid
import leaderboard

import datasets
import gradio as gr
import pandas as pd

from io_utils import (
    get_yaml_path,
    read_column_mapping,
    save_job_to_pipe,
    write_column_mapping,
    write_log_to_user_file,
)
from text_classification import (
    check_model_task,
    get_example_prediction,
    get_labels_and_features_from_dataset,
)
from wordings import (
    CHECK_CONFIG_OR_SPLIT_RAW,
    CONFIRM_MAPPING_DETAILS_FAIL_RAW,
    MAPPING_STYLED_ERROR_WARNING,
    get_styled_input,
)

# Maximum number of label / feature mapping dropdowns; the dropdown lists built
# below are padded (or truncated) to exactly these sizes.
MAX_LABELS = 40
MAX_FEATURES = 20

# Names of the environment variables read when building the scan command.
HF_REPO_ID = "HF_REPO_ID"
HF_SPACE_ID = "SPACE_ID"
HF_WRITE_TOKEN = "HF_WRITE_TOKEN"
HF_GSK_HUB_URL = "GSK_HUB_URL"
HF_GSK_HUB_PROJECT_KEY = "GSK_HUB_PROJECT_KEY"
HF_GSK_HUB_KEY = "GSK_API_KEY"
HF_GSK_HUB_HF_TOKEN = "GSK_HF_TOKEN"
HF_GSK_HUB_UNLOCK_TOKEN = "GSK_HUB_UNLOCK_TOKEN"

# Leaderboard dataset handed to the scanner when running in the official Space.
LEADERBOARD = "giskard-bot/evaluator-leaderboard"

# Module-level dataset state (not referenced in this part of the file; presumably
# used elsewhere). A `global` statement at module scope is a no-op, so none is
# needed here; any function that rebinds these must declare `global` itself.
ds_dict = None
ds_config = None

def get_related_datasets_from_leaderboard(model_id):
    """Return a dropdown update listing leaderboard datasets for ``model_id``.

    Assumes ``leaderboard.records`` is a pandas DataFrame with ``model_id`` and
    ``dataset_id`` columns (inferred from the indexing below).

    When the model already has leaderboard entries, the choices are the distinct
    datasets it was evaluated on, and the first one is pre-selected. Otherwise
    every distinct dataset on the leaderboard is offered with an empty selection.
    """
    all_records = leaderboard.records
    matching_rows = all_records[all_records["model_id"] == model_id]
    related = list(matching_rows["dataset_id"].unique())

    if related:
        return gr.update(choices=related, value=related[0])

    # Unknown model: fall back to every dataset seen on the leaderboard.
    every_dataset = list(all_records["dataset_id"].unique())
    return gr.update(choices=every_dataset, value="")


# Module logger named after the module (not its file path), so it joins the
# standard dotted logger hierarchy and can be configured by module name.
logger = logging.getLogger(__name__)


def check_dataset(dataset_id):
    """Populate the config and split dropdowns for a Hugging Face dataset.

    Looks up the dataset's config names and the split names of its first config.

    Returns:
        A 3-tuple ``(config_update, split_update, "")``. On success the config
        dropdown lists all configs and the split dropdown lists the splits of the
        first config; both are made visible with their first entry selected. If
        no config is found, or any lookup fails, both updates are no-ops. A
        failure is also logged as a warning.
    """
    logger.info(f"Loading {dataset_id}")
    try:
        configs = datasets.get_dataset_config_names(dataset_id)
        if not configs:
            return (gr.update(), gr.update(), "")
        # Only the split names are needed here; asking for them directly avoids
        # downloading and preparing the whole dataset just to list its splits.
        splits = list(datasets.get_dataset_split_names(dataset_id, configs[0]))
        return (
            gr.update(choices=configs, value=configs[0], visible=True),
            gr.update(choices=splits, value=splits[0], visible=True),
            "",
        )
    except Exception as e:
        # Best-effort UI helper: a bad dataset id must not crash the callback,
        # so log the problem and leave the dropdowns untouched.
        logger.warning(f"Check your dataset {dataset_id}: {e}")
        return (gr.update(), gr.update(), "")



def write_column_mapping_to_config(uid, *labels):
    """Persist the mapping dropdown selections into the column mapping of ``uid``.

    ``labels`` holds the dropdown values in UI order. The first ``MAX_LABELS``
    entries are assigned to the label keys already stored under ``"labels"``.
    The next ``MAX_FEATURES`` entries are assigned to the ``"text"`` feature.
    """
    # TODO: Substitute 'text' with more features for zero-shot
    # we are not using ds features because we only support "text" for now

    # ``*labels`` always binds a tuple (never None); with no values there is
    # nothing to write, and export_mappings cannot cycle over an empty sequence.
    if not labels:
        return

    all_mappings = read_column_mapping(uid)

    label_values = labels[:MAX_LABELS]
    feature_values = labels[MAX_LABELS : (MAX_LABELS + MAX_FEATURES)]

    all_mappings = export_mappings(all_mappings, "labels", None, label_values)
    # Fewer than MAX_LABELS + 1 values means no feature selection was passed.
    if feature_values:
        all_mappings = export_mappings(
            all_mappings,
            "features",
            ["text"],
            feature_values,
        )

    write_column_mapping(all_mappings, uid)


def export_mappings(all_mappings, key, subkeys, values):
    """Fill ``all_mappings[key]`` by pairing each subkey with a value.

    ``all_mappings`` is updated in place and also returned. The i-th subkey
    receives ``values[i % len(values)]``, so values repeat cyclically when there
    are fewer values than subkeys. Falsy subkeys are skipped but still consume a
    position.

    Args:
        all_mappings: Mapping dict to update. ``all_mappings[key]`` is created as
            an empty dict when missing.
        key: Top-level section to fill (e.g. ``"labels"`` or ``"features"``).
        subkeys: Entries to set. ``None`` means "re-assign every entry already
            present under ``key``".
        values: Sequence of values to assign. When empty, nothing is assigned.

    Returns:
        ``all_mappings``.
    """
    if key not in all_mappings:
        all_mappings[key] = dict()
    if subkeys is None:
        subkeys = list(all_mappings[key].keys())

    if not subkeys:
        logging.debug(f"subkeys is empty for {key}")
        return all_mappings
    # Guard the modulo below: with no values there is nothing to assign, and
    # ``i % len(values)`` would raise ZeroDivisionError.
    if not values:
        logging.debug(f"values is empty for {key}")
        return all_mappings

    for i, subkey in enumerate(subkeys):
        if subkey:
            all_mappings[key][subkey] = values[i % len(values)]
    return all_mappings


def list_labels_and_features_from_dataset(ds_labels, ds_features, model_labels, uid):
    """Build the mapping dropdowns and store a default column mapping for ``uid``.

    Creates one dropdown per dataset label, with the model labels as choices.
    Creates one dropdown for the ``"text"`` feature, with the dataset features as
    choices. Each list is padded with hidden dropdowns up to ``MAX_LABELS`` /
    ``MAX_FEATURES``.

    The default mapping written back via ``write_column_mapping`` has two parts:
    - each dataset label maps to a model label, assigned cyclically in sorted
      order;
    - ``"text"`` maps to the first dataset feature.

    Args:
        ds_labels: Label names found in the dataset. Not modified.
        ds_features: Feature (column) names found in the dataset; must be
            non-empty.
        model_labels: Label names predicted by the model; must be non-empty.
            Not modified.
        uid: Session id whose stored column mapping is updated.

    Returns:
        ``MAX_LABELS`` label dropdowns followed by ``MAX_FEATURES`` feature
        dropdowns.
    """
    all_mappings = read_column_mapping(uid)
    # For flattened raw datasets with no labels
    # check if there are shared labels between model and dataset
    shared_labels = set(model_labels).intersection(ds_labels)
    if shared_labels:
        # A set has no stable iteration order; sort it before truncating so the
        # labels that are kept are deterministic.
        ds_labels = sorted(shared_labels)
    if len(ds_labels) > MAX_LABELS:
        ds_labels = ds_labels[:MAX_LABELS]
        gr.Warning(f"The number of labels is truncated to length {MAX_LABELS}")

    # sorted() builds new lists, leaving the caller's arguments untouched
    # (list.sort() would reorder them in place as a side effect).
    ds_labels = sorted(ds_labels)
    model_labels = sorted(model_labels)

    label_dropdowns = [
        gr.Dropdown(
            label=f"{label}",
            choices=model_labels,
            value=model_labels[i % len(model_labels)],
            interactive=True,
            visible=True,
        )
        for i, label in enumerate(ds_labels)
    ]
    label_dropdowns += [
        gr.Dropdown(visible=False) for _ in range(MAX_LABELS - len(label_dropdowns))
    ]
    all_mappings = export_mappings(all_mappings, "labels", ds_labels, model_labels)

    # TODO: Substitute 'text' with more features for zero-shot
    feature_dropdowns = [
        gr.Dropdown(
            label=f"{feature}",
            choices=ds_features,
            value=ds_features[0],
            interactive=True,
            visible=True,
        )
        for feature in ["text"]
    ]
    feature_dropdowns += [
        gr.Dropdown(visible=False)
        for _ in range(MAX_FEATURES - len(feature_dropdowns))
    ]
    all_mappings = export_mappings(all_mappings, "features", ["text"], ds_features)
    write_column_mapping(all_mappings, uid)

    return label_dropdowns + feature_dropdowns


def precheck_model_ds_enable_example_btn(
    model_id, dataset_id, dataset_config, dataset_split
):
    """Check the model task and dataset, then preview the first rows of the split.

    Every branch returns the same three outputs:
    ``(interactive_update, dataframe_update, "")``.

    - The model is not a text-classification model: a warning is shown, the
      first component is disabled and the preview is left unchanged.
    - Config or split is missing: no-op updates.
    - Labels/features cannot be extracted: a warning is shown, the component is
      disabled, and the preview is shown.
    - Loading fails: a warning is shown, the component is disabled, and the
      preview is cleared and hidden.
    - Otherwise the component is enabled and the preview of the first 5 rows is
      shown.
    """
    model_task = check_model_task(model_id)
    if model_task is None or model_task != "text-classification":
        gr.Warning("Please check your model.")
        # Three values, like every other branch: Gradio expects one per output.
        return (gr.update(interactive=False), gr.update(), "")

    if not dataset_config or dataset_split is None:
        return (gr.update(), gr.update(), "")

    try:
        ds = datasets.load_dataset(dataset_id, dataset_config)
        ds_split = ds[dataset_split]
        # Convert only the preview rows: to_pandas() on the full split would
        # materialize the whole split in memory just to keep five rows.
        preview_rows = ds_split.select(range(min(5, len(ds_split))))
        df: pd.DataFrame = preview_rows.to_pandas()
        ds_labels, ds_features = get_labels_and_features_from_dataset(ds_split)

        if not isinstance(ds_labels, list) or not isinstance(ds_features, list):
            gr.Warning(CHECK_CONFIG_OR_SPLIT_RAW)
            return (gr.update(interactive=False), gr.update(value=df, visible=True), "")

        return (gr.update(interactive=True), gr.update(value=df, visible=True), "")
    except Exception as e:
        # Config or split wrong
        gr.Warning(f"Failed to load dataset {dataset_id} with config {dataset_config}: {e}")
        return (gr.update(interactive=False), gr.update(value=pd.DataFrame(), visible=False), "")




def _hidden_prediction_outputs():
    """Return the output tuple that hides the prediction and every mapping dropdown."""
    return (
        gr.update(visible=False),
        gr.update(visible=False),
        gr.update(visible=False, open=False),
        gr.update(interactive=False),
        "",
        *[gr.update(visible=False) for _ in range(MAX_LABELS + MAX_FEATURES)],
    )


def align_columns_and_show_prediction(
    model_id, dataset_id, dataset_config, dataset_split, uid, run_inference, inference_token
):
    """Run an example prediction and decide whether manual column mapping is needed.

    Outputs, in order:
    1. the styled example input, or a mapping warning;
    2. the example prediction;
    3. a collapsible section (``open``), presumably holding the mapping controls;
    4. a component enabled only when ``run_inference`` is truthy and
       ``inference_token`` is not ``""``;
    5. an empty string;
    6. ``MAX_LABELS + MAX_FEATURES`` mapping dropdowns.

    When model and dataset labels match as multisets and the first dataset
    feature is ``"text"``, the example prediction is shown. Otherwise the mapping
    section is opened so the user can map columns manually.
    """
    model_task = check_model_task(model_id)
    if model_task is None or model_task != "text-classification":
        gr.Warning("Please check your model.")
        return _hidden_prediction_outputs()

    prediction_input, prediction_output = get_example_prediction(
        model_id, dataset_id, dataset_config, dataset_split
    )

    model_labels = list(prediction_output.keys())

    ds = datasets.load_dataset(dataset_id, dataset_config)[dataset_split]
    ds_labels, ds_features = get_labels_and_features_from_dataset(ds)

    # When the dataset does not have labels or features. An empty feature list is
    # rejected too, because the mapping code below reads ds_features[0].
    if (
        not isinstance(ds_labels, list)
        or not isinstance(ds_features, list)
        or not ds_features
    ):
        gr.Warning(CHECK_CONFIG_OR_SPLIT_RAW)
        return _hidden_prediction_outputs()

    column_mappings = list_labels_and_features_from_dataset(
        ds_labels,
        ds_features,
        model_labels,
        uid,
    )

    submit_enabled = run_inference and inference_token != ""

    # when labels or features are not aligned
    # show manually column mapping
    if (
        collections.Counter(model_labels) != collections.Counter(ds_labels)
        or ds_features[0] != "text"
    ):
        return (
            gr.update(value=MAPPING_STYLED_ERROR_WARNING, visible=True),
            gr.update(visible=False),
            gr.update(visible=True, open=True),
            gr.update(interactive=submit_enabled),
            "",
            *column_mappings,
        )

    return (
        gr.update(value=get_styled_input(prediction_input), visible=True),
        gr.update(value=prediction_output, visible=True),
        gr.update(visible=True, open=False),
        gr.update(interactive=submit_enabled),
        "",
        *column_mappings,
    )


def check_column_mapping_keys_validity(all_mappings):
    """Check that a stored column mapping exists and has a ``"labels"`` entry.

    Returns ``None`` when the mapping is usable. Otherwise shows a Gradio
    warning and returns a pair of updates: the first makes a component
    interactive, the second hides one.
    """
    mapping_is_usable = all_mappings is not None and "labels" in all_mappings.keys()
    if mapping_is_usable:
        return None

    gr.Warning(CONFIRM_MAPPING_DETAILS_FAIL_RAW)
    return (gr.update(interactive=True), gr.update(visible=False))


def construct_label_and_feature_mapping(all_mappings):
    """Build the label and feature mappings passed to the scanner.

    Args:
        all_mappings: Stored column mapping. It must contain a ``"labels"`` dict,
            otherwise KeyError is raised, and should contain a ``"features"``
            entry.

    Returns:
        ``(label_mapping, feature_mapping)``. ``label_mapping`` maps the
        stringified position of each key in ``all_mappings["labels"]`` (in
        stored order) to that key. ``feature_mapping`` is
        ``all_mappings["features"]`` unchanged. If ``"features"`` is missing, a
        Gradio warning is shown and a pair of Gradio updates is returned
        instead.
    """
    label_mapping = {
        str(position): label for position, label in enumerate(all_mappings["labels"])
    }

    if "features" not in all_mappings:
        gr.Warning(CONFIRM_MAPPING_DETAILS_FAIL_RAW)
        return (gr.update(interactive=True), gr.update(visible=False))
    feature_mapping = all_mappings["features"]
    return label_mapping, feature_mapping


# One lock shared by every submission. A lock created inside each call is never
# contended by any other call, so it cannot serialize access to the job pipe.
_JOB_PIPE_LOCK = threading.Lock()


def try_submit(m_id, d_id, config, split, inference, inference_token, uid):
    """Queue a ``giskard_scanner`` run of model ``m_id`` on dataset ``d_id``.

    Reads the column mapping saved for ``uid`` and builds the scanner command
    line from it and from environment variables. The command is handed to
    ``save_job_to_pipe``, and a start message is written to the user's log file.

    Args:
        m_id: Model id.
        d_id: Dataset id.
        config: Dataset config name.
        split: Dataset split name.
        inference: When truthy, the scan uses the ``hf_inference_api``
            inference type.
        inference_token: Forwarded as ``--inference_api_token`` when
            ``inference`` is truthy.
        uid: Session id whose column mapping and scan config are used.

    Returns:
        A 3-tuple of outputs. On success: the submit button is disabled, a
        read-only 5-line component is shown, and a fresh ``uuid4`` is returned.
        If the stored mapping is incomplete, a warning is shown and nothing is
        queued. In that case the first two outputs re-enable one component and
        hide another, and ``uid`` is returned unchanged.
    """
    all_mappings = read_column_mapping(uid)
    # Stop here rather than building a scanner command from an incomplete mapping.
    invalid_mapping_outputs = check_column_mapping_keys_validity(all_mappings)
    if invalid_mapping_outputs is not None:
        return (*invalid_mapping_outputs, uid)
    if "features" not in all_mappings:
        gr.Warning(CONFIRM_MAPPING_DETAILS_FAIL_RAW)
        return (gr.update(interactive=True), gr.update(visible=False), uid)
    label_mapping, feature_mapping = construct_label_and_feature_mapping(all_mappings)

    leaderboard_dataset = None
    if os.environ.get(HF_SPACE_ID) == "giskardai/giskard-evaluator":
        leaderboard_dataset = LEADERBOARD

    # TODO: Set column mapping for some dataset such as `amazon_polarity`
    command = [
        "giskard_scanner",
        "--loader",
        "huggingface",
        "--model",
        m_id,
        "--dataset",
        d_id,
        "--dataset_config",
        config,
        "--dataset_split",
        split,
        "--output_format",
        "markdown",
        "--output_portal",
        "huggingface",
        "--feature_mapping",
        json.dumps(feature_mapping),
        "--label_mapping",
        json.dumps(label_mapping),
        "--scan_config",
        get_yaml_path(uid),
    ]

    if inference:
        command += [
            "--inference_type",
            "hf_inference_api",
            "--inference_api_token",
            inference_token,
        ]
    # NOTE(review): no inference type is known for the non-API path, so the
    # flags are omitted and the scanner's own default applies; confirm this is
    # the intended mode.

    # The token to publish post
    write_token = os.environ.get(HF_WRITE_TOKEN)
    if write_token:
        command += ["--hf_token", write_token]

    # The repo to publish post
    # TODO: Replace by the model id
    discussion_repo = os.environ.get(HF_REPO_ID) or os.environ.get(HF_SPACE_ID)
    if discussion_repo:
        command += ["--discussion_repo", discussion_repo]

    # The repo to publish for ranking
    if leaderboard_dataset:
        command += ["--leaderboard_dataset", leaderboard_dataset]

    # The info to upload to Giskard hub (only when a hub API key is configured).
    if os.environ.get(HF_GSK_HUB_KEY):
        for flag, env_name in (
            ("--giskard_hub_api_key", HF_GSK_HUB_KEY),
            ("--giskard_hub_url", HF_GSK_HUB_URL),
            ("--giskard_hub_project_key", HF_GSK_HUB_PROJECT_KEY),
            ("--giskard_hub_hf_token", HF_GSK_HUB_HF_TOKEN),
            ("--giskard_hub_unlock_token", HF_GSK_HUB_UNLOCK_TOKEN),
        ):
            env_value = os.environ.get(env_name)
            if env_value:
                command += [flag, env_value]

    eval_str = f"[{m_id}]<{d_id}({config}, {split} set)>"
    logger.info(f"Start local evaluation on {eval_str}")
    save_job_to_pipe(uid, command, eval_str, _JOB_PIPE_LOCK)

    write_log_to_user_file(
        uid,
        f"Start local evaluation on {eval_str}. Please wait for your job to start...\n",
    )
    gr.Info(f"Start local evaluation on {eval_str}")

    return (
        gr.update(interactive=False),  # Submit button
        gr.update(lines=5, visible=True, interactive=False),
        uuid.uuid4(),  # Allocate a new uuid
    )