import argparse
import json
import os

from typing import Optional, Tuple
from tqdm.auto import tqdm

import torch

from datasets import Dataset, load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
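
# This script evaluates a zero-shot search-query parser: for each sample in the
# chosen split, it generates the model's JSON parse and records it alongside the
# ground-truth parse and category embedded in the instruction text.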

def check_base_path(path: str) -> Optional[str]:
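    """Argparse type helper: verify that the directory part of ``path`` exists."""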
    if path is not None:
        # basename() would only check the file name itself; what we actually
        # need is the containing directory ('' means the current directory).
        base_dir = os.path.dirname(path) or '.'
        if not os.path.exists(base_dir):
            raise argparse.ArgumentTypeError(f'Directory not found: {base_dir}')
    return path


def parse_args():
    DEFAULT_MODEL_ID = 'EmbeddingStudio/query-parser-falcon-7b-instruct'
    DEFAULT_DATASET = 'EmbeddingStudio/query-parsing-instructions-falcon'
    DEFAULT_SPLIT = 'test'
    DEFAULT_INSTRUCTION_FIELD = 'text'
    DEFAULT_RESPONSE_DELIMITER = '## Response:\n'
    DEFAULT_CATEGORY_DELIMITER = '## Category:'
    DEFAULT_OUTPUT_PATH = f'{DEFAULT_MODEL_ID.split("/")[-1]}-test.json'

    parser = argparse.ArgumentParser(description='EmbeddingStudio script for testing Zero-Shot Search Query Parsers')
    parser.add_argument("--model-id",
                        help=f"Huggingface model ID (default: {DEFAULT_MODEL_ID})",
                        default=DEFAULT_MODEL_ID,
                        type=str,
    )
    parser.add_argument("--dataset-name",
                        help=f"Huggingface dataset name which contains instructions (default: {DEFAULT_DATASET})",
                        default=DEFAULT_DATASET,
                        type=str,
    )
    parser.add_argument("--dataset-split",
                        help=f"Huggingface dataset split name (default: {DEFAULT_SPLIT})",
                        default=DEFAULT_SPLIT,
                        type=str,
    )
    parser.add_argument("--dataset-instructions-field",
                        help=f"Huggingface dataset field with instructions (default: {DEFAULT_INSTRUCTION_FIELD})",
                        default=DEFAULT_INSTRUCTION_FIELD,
                        type=str,
    )
    parser.add_argument("--instructions-response-delimiter",
                        help=f"Instruction response delimiter (default: {DEFAULT_RESPONSE_DELIMITER})",
                        default=DEFAULT_RESPONSE_DELIMITER,
                        type=str,
    )
    parser.add_argument("--instructions-category-delimiter",
                        help=f"Instruction category name delimiter (default: {DEFAULT_CATEGORY_DELIMITER})",
                        default=DEFAULT_CATEGORY_DELIMITER,
                        type=str,
    )

    parser.add_argument("--output",
                        help=f"JSON file with test results (default: {DEFAULT_OUTPUT_PATH})",
                        default=DEFAULT_OUTPUT_PATH,
                        type=check_base_path,
    )
    args = parser.parse_args()
    return args


def load_model(model_id: str) -> Tuple[AutoTokenizer, AutoModelForCausalLM]:
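    """Load the tokenizer and the model, mapping the whole model onto GPU 0."""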
    tokenizer = AutoTokenizer.from_pretrained(
        model_id,
        trust_remote_code=True,
        add_prefix_space=True,
        use_fast=False,
    )
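    # The tokenizer has no dedicated pad token; reuse EOS for padding.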
    tokenizer.pad_token = tokenizer.eos_token
    model = AutoModelForCausalLM.from_pretrained(model_id, device_map={"": 0})
    return tokenizer, model


@torch.no_grad()
def predict(
        tokenizer: AutoTokenizer,
        model: AutoModelForCausalLM,
        dataset: Dataset,
        index: int,
        field_name: str = 'text',
        response_delimiter: str = '## Response:\n',
        category_delimiter: str = '## Category:'
) -> Tuple[dict, dict, str]:
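    """Run the model on a single sample and parse its generated JSON response.

    Returns the model's parsed JSON, the ground-truth JSON, and the category name.
    """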
    sample = dataset[index][field_name]
    # The prompt is everything up to (and including) the response delimiter;
    # the ground-truth response and category are embedded after it.
    input_text = sample.split(response_delimiter)[0] + response_delimiter
    input_ids = tokenizer.encode(input_text, return_tensors='pt')
    real = json.loads(sample.split(response_delimiter)[-1])
    category = sample.split(category_delimiter)[-1].split('\n')[0].strip()

    # Generate the response; a very low temperature keeps the JSON output stable.
    output = model.generate(
        input_ids.to(model.device),
        max_new_tokens=1000,
        do_sample=True,
        temperature=0.05,
        # Pad with EOS: the previously hard-coded 50256 is GPT-2's EOS id, not Falcon's.
        pad_token_id=tokenizer.eos_token_id,
    )
    decoded = tokenizer.decode(output[0], skip_special_tokens=True)
    parsed = json.loads(decoded.split(response_delimiter)[-1])

    return parsed, real, category


@torch.no_grad()
def test_model(model_id: str,
               dataset_name: str,
               split_name: str,
               field_name: str,
               response_delimiter: str,
               category_delimiter: str,
               output_path: str,
):
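    """Evaluate the parser on every sample of the split and write results to JSON."""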
    dataset = load_dataset(dataset_name, split=split_name)
    tokenizer, model = load_model(model_id)
    model.eval()

    test_results = []
    # load_dataset(..., split=...) returns a single Dataset, so index it directly.
    for index in tqdm(range(len(dataset))):
        try:
            test_results.append(
                predict(tokenizer, model, dataset, index, field_name, response_delimiter, category_delimiter)
            )
        except Exception as e:
            # Skip samples whose ground truth or generated output is not valid JSON.
            print(f'Skipping sample {index}: {e}')
            continue

    with open(output_path, 'w') as f:
        json.dump(test_results, f)


if __name__ == '__main__':
    args = parse_args()
    test_model(
        args.model_id,
        args.dataset_name,
        args.dataset_split,
        args.dataset_instructions_field,
        args.instructions_response_delimiter,
        args.instructions_category_delimiter,
        args.output
    )