import csv
import functools
import json
import multiprocessing as mp
import os
from typing import Any, Dict, List, NewType, Optional, Union

import numpy as np
import yaml
from datasets import Dataset, DatasetDict, load_dataset
from easygoogletranslate import EasyGoogleTranslate
from langchain.prompts import FewShotPromptTemplate, PromptTemplate
from tqdm import tqdm
from yaml.loader import SafeLoader

# Maps the language names used in run configs to the codes passed both to
# ``load_dataset`` (as the dataset configuration name) and to
# EasyGoogleTranslate (as source/target language).
LANGUAGE_TO_SUFFIX = {
    "chinese_simplified": "zh-CN",
    "french": "fr",
    "portuguese": "pt",
    "english": "en",
    "arabic": "ar",
    "hindi": "hi",
    "indonesian": "id",
    "amharic": "am",
    "bengali": "bn",
    "burmese": "my",
    # "chinese" used to be listed twice ("zh-CN" here and "zh" further down);
    # in a dict literal the last value wins while the key keeps its first
    # position, so "zh" at this position was always the effective entry.
    "chinese": "zh",
    "swahili": "sw",
    "bulgarian": "bg",
    "thai": "th",
    "urdu": "ur",
    "turkish": "tr",
    "spanish": "es",
    "greek": "el",
    "german": "de",
}

# Integer dataset labels -> tag names shown in few-shot demonstrations.
NUMBER_TO_TAG = {0: "entailment", 1: "neutral", 2: "contradiction"}

# Type alias for the parsed YAML run configuration.
PARAMS = NewType("PARAMS", Dict[str, Any])


def read_parameters(args_path) -> PARAMS:
    """Parse the YAML run configuration stored at ``args_path``.

    The document is parsed with ``SafeLoader``, so only plain YAML types are
    constructed.
    """
    with open(args_path) as stream:
        return yaml.load(stream, Loader=SafeLoader)


def get_key(key_path):
    """Return the first line of the file at ``key_path``, without its newline."""
    with open(key_path) as key_file:
        contents = key_file.read()
    first_line, _separator, _rest = contents.partition("\n")
    return first_line


def _translate_example(
    example: Dict[str, str], src_language: str, target_language: str
):
    """Machine-translate the premise and hypothesis of an NLI example.

    Args:
        example: Mapping with ``premise`` and ``hypothesis`` keys and,
            optionally, a ``label`` key.
        src_language: Key of ``LANGUAGE_TO_SUFFIX`` for the example's language.
        target_language: Key of ``LANGUAGE_TO_SUFFIX`` to translate into.

    Returns:
        Dict with the translated ``premise`` and ``hypothesis`` and the
        untranslated ``label`` copied from ``example`` ("" when absent).

    Raises:
        KeyError: If a language is missing from ``LANGUAGE_TO_SUFFIX``.
        Exception: Whatever translation raises; it is printed, then re-raised.
    """
    translator = EasyGoogleTranslate(
        source_language=LANGUAGE_TO_SUFFIX[src_language],
        target_language=LANGUAGE_TO_SUFFIX[target_language],
        timeout=30,
    )
    try:
        premise = translator.translate(example["premise"])
        hypothesis = translator.translate(example["hypothesis"])
    except Exception as e:
        # Returning None here used to resurface later as a None demonstration
        # in the few-shot prompt or a TypeError on subscripting; fail loudly.
        print(e)
        raise
    return {
        "premise": premise,
        "hypothesis": hypothesis,
        # Keep the gold tag: a few-shot demonstration without its answer
        # teaches the model nothing.
        "label": example.get("label", ""),
    }


def choose_few_shot_examples(
    train_dataset: Dataset,
    few_shot_size: int,
    context: List[str],
    selection_criteria: str,
    lang: str,
) -> List[Dict[str, Union[str, int]]]:
    """Select few-shot examples and render each in its requested language.

    Args:
        train_dataset (Dataset): Pool to draw from; ``train_dataset[i]`` must
            return a mapping with ``premise``, ``hypothesis`` and an integer
            ``label`` that is a key of ``NUMBER_TO_TAG``.
        few_shot_size (int): Number of examples to draw.
        context (List[str]): Language of each in-context example; entry ``i``
            is applied to the ``i``-th drawn example.
        selection_criteria (str): ``"first_k"`` takes the first
            ``few_shot_size`` items; ``"random"`` samples indices uniformly
            with replacement.
        lang (str): Language of ``train_dataset``; examples whose context
            language differs are machine-translated from it.

    Returns:
        List[Dict[str, Union[str, int]]]: One example per ``context`` entry,
        with ``label`` mapped to its tag name.

    Raises:
        ValueError: If ``selection_criteria`` is not recognised.
        IndexError: If ``context`` has more entries than examples drawn.
    """
    if selection_criteria == "first_k":
        example_idxs = list(range(few_shot_size))
    elif selection_criteria == "random":
        # Sampling with replacement can repeat a demonstration; kept so seeded
        # runs stay reproducible.
        example_idxs = (
            np.random.choice(len(train_dataset), size=few_shot_size, replace=True)
            .astype(int)
            .tolist()
        )
    else:
        # Previously an unknown value produced no examples and failed later
        # with an unrelated IndexError.
        raise ValueError(
            f"Unknown selection_criteria {selection_criteria!r}; "
            "expected 'first_k' or 'random'"
        )

    drawn = [train_dataset[idx] for idx in example_idxs]
    ic_examples = [
        {
            "premise": example["premise"],
            "hypothesis": example["hypothesis"],
            "label": NUMBER_TO_TAG[example["label"]],
        }
        for example in drawn
    ]

    selected_examples = []
    for idx, ic_language in enumerate(context):
        example = ic_examples[idx]
        if ic_language != lang:
            example = _translate_example(
                example=example,
                src_language=lang,
                target_language=ic_language,
            )
        selected_examples.append(example)

    return selected_examples


def load_xnli_dataset(
    dataset_name: str,
    lang: str,
    split: str,
    limit: Optional[int] = 200,
) -> Union[Dataset, DatasetDict]:
    """Load the leading rows of one split of XNLI or IndicXNLI.

    Args:
        dataset_name (str): ``"indicxnli"`` loads ``Divyanshu/indicxnli``;
            any other value loads ``xnli``.
        lang (str): Key of ``LANGUAGE_TO_SUFFIX``; its code is used as the
            dataset configuration name.
        split (str): Split to load (e.g. ``"train"``, ``"validation"``,
            ``"test"``).
        limit (Optional[int]): Maximum number of leading rows to keep. Clamped
            to the split size; ``None`` keeps the whole split. Defaults to 200.

    Returns:
        Union[Dataset, DatasetDict]: huggingface dataset object.

    Raises:
        KeyError: If ``lang`` is not in ``LANGUAGE_TO_SUFFIX``.
    """
    # TODO(review): original note "PJ: To add except hindi" suggests IndicXNLI
    # support for languages other than Hindi is unfinished — confirm.
    if dataset_name == "indicxnli":
        dataset = load_dataset("Divyanshu/indicxnli", LANGUAGE_TO_SUFFIX[lang])[split]
    else:
        dataset = load_dataset("xnli", LANGUAGE_TO_SUFFIX[lang])[split]
    if limit is None:
        return dataset
    # Asking select() for indices past the end of the split fails, so never
    # request more rows than the split has.
    return dataset.select(np.arange(min(limit, len(dataset))))


def construct_prompt(
    instruction: str, test_example: dict, ic_examples: List[dict], zero_shot: bool
):
    """Render an NLI prompt from pre-selected in-context examples.

    NOTE(review): this module defines ``construct_prompt`` a second time further
    down with a different signature. That later definition rebinds the name, so
    this version is never reachable by name at runtime. Rename or remove one.

    Args:
        instruction: Task instruction; used as the few-shot prefix, or placed
            in front of the zero-shot template.
        test_example: Mapping with ``premise``, ``hypothesis`` and ``label``.
        ic_examples: Demonstrations with ``premise``, ``hypothesis`` and
            ``label`` keys; unused when ``zero_shot`` is true.
        zero_shot: If true, format with the zero-shot template only.

    Returns:
        Tuple of (formatted prompt string, ``test_example["label"]``).
    """
    example_prompt = PromptTemplate(
        input_variables=["premise", "hypothesis", "label"],
        # NOTE(review): no separator between "Label" and the value — intended?
        template="Premise: {premise}\n Hypothesis: {hypothesis} \n Label{label}",
    )

    # NOTE(review): the literal " + " inside this string ends up in the prompt
    # text; it looks like a leftover concatenation operator.
    zero_shot_template = (
        f"""{instruction}""" + "\n hypothesis: {hypothesis} + \n  Premise: {premise}" ""
    )

    # Few-shot: instruction prefix, rendered demonstrations, then the test pair.
    prompt = (
        FewShotPromptTemplate(
            examples=ic_examples,
            prefix=instruction,
            example_prompt=example_prompt,
            suffix="Premise: {premise} \n Hypothesis: {hypothesis}",
            input_variables=["hypothesis", "premise"],
        )
        if not zero_shot
        else PromptTemplate(
            input_variables=["hypothesis", "premise"], template=zero_shot_template
        )
    )

    return (
        prompt.format(
            hypothesis=test_example["hypothesis"], premise=test_example["premise"]
        ),
        test_example["label"],
    )


def dump_metrics(
    lang: str,
    config: Dict[str, Any],
    r1: float,
    r2: float,
    rL: float,
    metric_logger_path: str,
):
    """Append one row of ROUGE scores for a configuration to a CSV log.

    A header row is written first when the file does not exist yet or is empty.

    Args:
        lang: Language the run was evaluated on.
        config: Run configuration; reads ``prefix``, ``input``, ``context``
            (a list of languages, empty for zero-shot) and ``output``.
        r1: ROUGE-1 value to record.
        r2: ROUGE-2 value to record.
        rL: ROUGE-L value to record.
        metric_logger_path: Path of the CSV file to append to.
    """
    # An existing but empty file still needs its header.
    needs_header = (
        not os.path.exists(metric_logger_path)
        or os.path.getsize(metric_logger_path) == 0
    )
    # Zero-shot runs have an empty context list (context[0] would raise);
    # record them as "zero", the same suffix used for zero-shot run names.
    context = config["context"]
    context_label = context[0] if context else "zero"

    with open(metric_logger_path, "a", newline="", encoding="utf-8") as f:
        csvwriter = csv.writer(f, delimiter=",")

        if needs_header:
            csvwriter.writerow(
                [
                    "Language",
                    "Prefix",
                    "Input",
                    "Context",
                    "Output",
                    "R1",
                    "R2",
                    "RL",
                ]
            )

        csvwriter.writerow(
            [
                lang,
                config["prefix"],
                config["input"],
                context_label,
                config["output"],
                r1,
                r2,
                rL,
            ]
        )


def dump_predictions(idx, response, label, response_logger_file):
    """Append one prediction record as a JSON line to ``response_logger_file``.

    Args:
        idx: Index of the test example (stored as ``q_idx``).
        response: Model prediction.
        label: Gold label of the example.
        response_logger_file: Path of the file to append to.
    """
    obj = {"q_idx": idx, "prediction": response, "label": label}
    # ensure_ascii=False writes non-ASCII text verbatim, so the file encoding
    # must be pinned; a non-UTF-8 platform default cannot encode e.g. Hindi.
    with open(response_logger_file, "a", encoding="utf-8") as f:
        f.write(json.dumps(obj, ensure_ascii=False) + "\n")


def compute_rouge(scorer, pred, label):
    """Score a prediction against its reference with a ROUGE scorer.

    Args:
        scorer: Object with a ``score(target, prediction)`` method returning a
            mapping with ``"rouge1"``, ``"rouge2"`` and ``"rougeL"`` entries
            (presumably ``rouge_score.rouge_scorer.RougeScorer`` — confirm).
        pred: Model prediction text.
        label: Reference (gold) text.

    Returns:
        Tuple of the ``rouge1``, ``rouge2`` and ``rougeL`` entries.
    """
    # The reference goes first: RougeScorer.score(target, prediction). Passing
    # the prediction first swaps precision and recall in every score.
    score = scorer.score(label, pred)
    return score["rouge1"], score["rouge2"], score["rougeL"]


def _translate_instruction(basic_instruction: str, target_language: str) -> str:
    """Translate an English instruction into ``target_language``.

    ``target_language`` is a key of ``LANGUAGE_TO_SUFFIX``; the request uses a
    10 second timeout.
    """
    target_code = LANGUAGE_TO_SUFFIX[target_language]
    english_to_target = EasyGoogleTranslate(
        source_language="en", target_language=target_code, timeout=10
    )
    return english_to_target.translate(basic_instruction)


def _translate_prediction_to_output_language(
    prediction: str, prediction_language: str, output_language: str
) -> str:
    """Translate ``prediction`` between two ``LANGUAGE_TO_SUFFIX`` languages.

    The request uses a 10 second timeout.
    """
    source_code = LANGUAGE_TO_SUFFIX[prediction_language]
    target_code = LANGUAGE_TO_SUFFIX[output_language]
    return EasyGoogleTranslate(
        source_language=source_code, target_language=target_code, timeout=10
    ).translate(prediction)


@functools.lru_cache(maxsize=32)
def create_instruction(lang: str):
    """Return the NLI task instruction in ``lang``.

    English returns the built-in text unchanged; any other language is
    machine-translated from it. Results are memoised per process: this runs
    once per test example, and without the cache every non-English call sent
    the identical text to the translation service again.

    Args:
        lang: Key of ``LANGUAGE_TO_SUFFIX`` (``"english"`` skips translation).

    Returns:
        The instruction string.
    """
    # Plain literal (no placeholders); its exact whitespace is part of the prompt.
    basic_instruction = """
        You are an NLP assistant whose purpose is to solve Natural Language Inference (NLI) problems.
        NLI is the task of determining the inference relation between two texts: entailment,
        contradiction, or neutral. 
        Your answer should be one word of the following - entailment, contradiction, or neutral. 
        Pay attention: The output should be only one word!!!!
        """
    if lang == "english":
        return basic_instruction
    return _translate_instruction(basic_instruction, target_language=lang)


def run_one_configuration(params: Optional[PARAMS] = None, zero: bool = False):
    """Run every test example of one configuration through worker processes.

    Args:
        params: Parsed run parameters. When falsy they are read from
            ``../../parameters.yaml``. Reads ``selected_language``,
            ``config``, ``dataset_name``, ``limit`` and, optionally,
            ``num_processes`` (worker count, default 3).
        zero: Force the ``_zero`` suffix in the output file name.

    Each example is handled by ``process_test_example`` in a pool worker; the
    progress bar advances as tasks complete.
    """
    if not params:
        params = read_parameters("../../parameters.yaml")

    lang = params["selected_language"]
    config = params["config"]
    zero_shot = len(config["context"]) == 0

    # An empty context list has no first language to name the file after
    # (context[0] would raise), so it gets the zero-shot suffix too.
    if zero or zero_shot:
        config_header = f"{config['input']}_{config['prefix']}_zero"
    else:
        config_header = f"{config['input']}_{config['prefix']}_{config['context'][0]}"
    test_data = load_xnli_dataset(
        dataset_name=params["dataset_name"],
        lang=lang,
        split="test",
        limit=params["limit"],
    )

    with mp.Pool(processes=params.get("num_processes", 3)) as pool:
        pending = [
            pool.apply_async(
                process_test_example,
                args=(
                    test_data,
                    config_header,
                    idx,
                    test_example,
                    config,
                    zero_shot,
                    lang,
                    params,
                ),
            )
            for idx, test_example in enumerate(test_data)
        ]
        # Wait on the results rather than the submissions so the bar tracks
        # real progress, and report failures that discarded AsyncResults used
        # to hide (e.g. arguments that cannot be pickled).
        for idx, result in enumerate(tqdm(pending, total=len(pending))):
            try:
                result.get()
            except Exception as e:
                print(f"Task for example {idx} failed: {e}")
        pool.close()
        pool.join()


def _format_prompt_from_ic_examples(
    instruction: str, test_example: dict, ic_examples: List[dict], zero_shot: bool
):
    """Render a prompt from pre-selected demonstrations; return (prompt, label)."""
    example_prompt = PromptTemplate(
        input_variables=["premise", "hypothesis", "label"],
        template="Premise: {premise}\n Hypothesis: {hypothesis} \n Label{label}",
    )
    if zero_shot:
        # Template text kept verbatim from the original, including the literal " + ".
        prompt = PromptTemplate(
            input_variables=["hypothesis", "premise"],
            template=f"""{instruction}"""
            + "\n hypothesis: {hypothesis} + \n  Premise: {premise}"
            "",
        )
    else:
        prompt = FewShotPromptTemplate(
            examples=ic_examples,
            prefix=instruction,
            example_prompt=example_prompt,
            suffix="Premise: {premise} \n Hypothesis: {hypothesis}",
            input_variables=["hypothesis", "premise"],
        )
    formatted = prompt.format(
        hypothesis=test_example["hypothesis"], premise=test_example["premise"]
    )
    return formatted, test_example["label"]


def process_test_example(
    test_data, config_header, idx, test_example, config, zero_shot, lang, params
):
    """Build the prompt for one test example, query the model and log the answer.

    Intended to run in a pool worker: any ``Exception`` is caught and printed
    so one failing example does not stop the others.

    Args:
        test_data: Dataset the example comes from; also the demonstration pool.
        config_header: File stem of the response log.
        idx: Position of ``test_example`` in ``test_data``.
        test_example: Mapping with ``premise``, ``hypothesis`` and ``label``.
        config: Run configuration (``prefix``, ``context`` list, ...).
        zero_shot: If true, no demonstrations are added.
        lang: Language of ``test_data``; names the log directory.
        params: Run parameters (``selected_language``, ``response_logger_root``,
            ``model`` and optional ``endpoint_id`` / ``project_id``).
    """
    try:
        instruction = create_instruction(lang=config["prefix"])
        text_example = {
            "premise": test_example["premise"],
            "hypothesis": test_example["hypothesis"],
            "label": test_example["label"],
        }

        ic_examples = []
        if not zero_shot:
            # Draw demonstrations from every row except the one being
            # evaluated, so its gold answer cannot leak into the prompt.
            candidate_pool = test_data.select(
                [i for i in range(len(test_data)) if i != idx]
            )
            ic_examples = choose_few_shot_examples(
                train_dataset=candidate_pool,
                few_shot_size=len(config["context"]),
                context=config["context"],
                selection_criteria="random",
                lang=params["selected_language"],
            )

        # The module-level name ``construct_prompt`` is rebound by a later
        # definition that takes no ``ic_examples`` and returns only a string,
        # so calling it here raised TypeError for every example.
        prompt, label = _format_prompt_from_ic_examples(
            instruction=instruction,
            test_example=text_example,
            ic_examples=ic_examples,
            zero_shot=zero_shot,
        )

        # NOTE(review): get_prediction is not defined or imported in the
        # visible part of this module — confirm where it comes from.
        pred = get_prediction(
            prompt=prompt,
            endpoint_id=params.get("endpoint_id", 7327255438662041600),
            project_id=params.get("project_id", 16514800572),
        )
        print(pred)

        log_dir = f"{params['response_logger_root']}/{params['model']}/{lang}"
        os.makedirs(log_dir, exist_ok=True)
        # NOTE(review): records are JSON lines despite the ".csv" extension.
        dump_predictions(
            idx=idx,
            response=pred,
            label=label,
            response_logger_file=f"{log_dir}/{config_header}.csv",
        )

    except Exception as e:
        print(f"Error processing example {idx}: {e}")


def construct_prompt(
    instruction: str,
    test_example: dict,
    zero_shot: bool,
    num_examples: int,
    lang: str,
    config: Dict[str, Any],
    dataset_name: str = "xnli",
):
    """Build the formatted NLI prompt for one test example.

    Args:
        instruction: Task instruction; when empty, ``create_instruction(lang)``
            is used.
        test_example: Mapping with ``premise`` and ``hypothesis``, in ``lang``.
        zero_shot: If true, no demonstrations are added.
        num_examples: Number of demonstrations to draw when not zero-shot.
        lang: Language of ``test_example`` and of the demonstration pool.
        config: Reads ``input`` (language the test pair is presented in) and,
            for few-shot, ``context``: either one language name, applied to
            all ``num_examples`` demonstrations, or a list with one language
            per demonstration.
        dataset_name: Dataset the demonstrations are drawn from.

    Returns:
        str: The formatted prompt (the label is not returned).

    Raises:
        KeyError: If few-shot is requested for a ``lang`` that
            ``load_xnli_dataset`` cannot map.
        RuntimeError: If translating the test example returns no result.
    """
    if not instruction:
        instruction = create_instruction(lang)

    example_prompt = PromptTemplate(
        input_variables=["premise", "hypothesis", "label"],
        template="Premise {premise}\n Hypothesis {hypothesis} \n{label}",
    )

    zero_shot_template = (
        f"""{instruction}""" + "\n Hypothesis: {hypothesis} + \n  Premise: {premise}" ""
    )

    ic_examples = []
    if not zero_shot:
        try:
            test_data = load_xnli_dataset(dataset_name, lang, split="test", limit=100)
        except KeyError as e:
            raise KeyError(
                f"{lang} is not supported in {dataset_name} dataset, choose supported language in few-shot"
            ) from e

        # A single language name is repeated for every demonstration; a list
        # (the form used by the rest of this module) is used as given.
        # Wrapping a list in another list made every language comparison fail.
        context = config["context"]
        if isinstance(context, str):
            context = [context] * num_examples
        # NOTE(review): demonstrations come from the test split and may
        # include the example being prompted.
        ic_examples = choose_few_shot_examples(
            train_dataset=test_data,
            few_shot_size=num_examples,
            context=context,
            selection_criteria="random",
            lang=lang,
        )

    prompt = (
        FewShotPromptTemplate(
            examples=ic_examples,
            prefix=instruction,
            example_prompt=example_prompt,
            suffix="{premise} \n{hypothesis}",
            input_variables=["hypothesis", "premise"],
        )
        if not zero_shot
        else PromptTemplate(
            input_variables=["hypothesis", "premise"], template=zero_shot_template
        )
    )

    if config["input"] != lang:
        translated = _translate_example(
            example=test_example, src_language=lang, target_language=config["input"]
        )
        # Guard against a failed translation yielding None instead of a dict.
        if translated is None:
            raise RuntimeError(
                f"Translating the test example from {lang} to {config['input']} failed"
            )
        test_example = translated

    return prompt.format(
        hypothesis=test_example["hypothesis"], premise=test_example["premise"]
    )