import os
from typing import Tuple

import click
import pandas as pd
from datasets import Dataset
from langchain.chains import LLMChain
from langchain.chains.question_answering import load_qa_chain
from langchain.prompts import PromptTemplate
from langchain_community.embeddings.sentence_transformer import SentenceTransformerEmbeddings
from langchain_google_vertexai import ChatVertexAI
from loguru import logger
from ragas import evaluate
from ragas.embeddings import LangchainEmbeddingsWrapper
from ragas.llms import LangchainLLMWrapper
from ragas.metrics import (
    answer_relevancy,
    context_precision, answer_correctness,
)
from tqdm import tqdm

from app.chroma import ChromaDenseVectorDB
from app.config.load import load_config
from app.config.models.configs import Config
from app.config.models.vertexai import VertexAIModel
from app.parsers.splitter import DocumentSplitter
from app.pipeline import LLMBundle
from app.ranking import BCEReranker
from app.splade import SpladeSparseVectorDB


def get_hash_mapping_filenames(
        config: Config,
        file_to_hash_fn: str = "file_hash_mappings.snappy.parquet",
        docid_to_hash_fn="docid_hash_mappings.snappy.parquet",
) -> Tuple[str, str]:
    """Build the paths of the two hash-mapping parquet files.

    Both file names are resolved relative to ``config.embeddings.embeddings_path``.

    Args:
        config: Application config providing ``embeddings.embeddings_path``.
        file_to_hash_fn: File name of the file -> hash mapping table.
        docid_to_hash_fn: File name of the document-id -> hash mapping table.

    Returns:
        A ``(file_hashes_path, docid_hashes_path)`` tuple.
    """
    base_dir = config.embeddings.embeddings_path
    file_hashes_path = os.path.join(base_dir, file_to_hash_fn)
    docid_hashes_path = os.path.join(base_dir, docid_to_hash_fn)
    return file_hashes_path, docid_hashes_path


@click.group()
def main():
    """RAG pipeline tooling: build indexes, generate predictions and evaluate them."""
    # Click uses this docstring as the group's --help description; the
    # subcommands (index / predict / evaluate) are registered below.


@main.command(name="index")
@click.option(
    "-c",
    "app_config_path",
    required=True,
    # Same validation as the `predict` command: fail at argument parsing time
    # with a clear usage error if the config file does not exist.
    type=click.Path(exists=True, dir_okay=False, file_okay=True),
    help="Specifies App JavaScript configuration file (should be module exported)",
)
def create_index(app_config_path: str):
    """Split the configured documents and build the dense and sparse indexes.

    Dense (Chroma) embeddings are persisted under
    ``config.embeddings.embeddings_path``; SPLADE sparse embeddings are generated
    via ``SpladeSparseVectorDB``. The file->hash and docid->hash mapping tables
    are then written as snappy-compressed parquet to the paths returned by
    ``get_hash_mapping_filenames``.
    """
    config = load_config(app_config_path)

    dense_db = ChromaDenseVectorDB(
        persist_folder=str(config.embeddings.embeddings_path), config=config
    )
    splitter = DocumentSplitter(config)
    all_docs, all_hash_filename_mappings, all_hash_docid_mappings = splitter.split()

    # dense embeddings
    dense_db.generate_embeddings(docs=all_docs)

    # sparse embeddings
    sparse_db = SpladeSparseVectorDB(config)
    sparse_db.generate_embeddings(docs=all_docs)

    file_hashes_fn, docid_hashes_fn = get_hash_mapping_filenames(config)

    all_hash_filename_mappings.to_parquet(
        file_hashes_fn, compression="snappy", index=False
    )
    all_hash_docid_mappings.to_parquet(
        docid_hashes_fn, compression="snappy", index=False
    )

    logger.info("Document Embeddings Generated")


def _build_llm_bundle(config: Config) -> LLMBundle:
    """Assemble the retrieval + generation pipeline used by the ``predict`` command.

    Components are created in this order: Vertex AI model and QA chain, the
    persisted Chroma dense store, the BCE reranker, the SPLADE sparse store and a
    HyDE-style passage-generation chain.
    """
    llm = VertexAIModel(config=config.llm.params)
    chain = load_qa_chain(llm=llm.model, prompt=llm.prompt)

    dense_db = ChromaDenseVectorDB(
        persist_folder=str(config.embeddings.embeddings_path), config=config
    )
    # NOTE(review): calls a private method of ChromaDenseVectorDB; a public
    # loader on that class would be a cleaner contract.
    dense_db._load_retriever()

    reranker = BCEReranker()

    sparse_db = SpladeSparseVectorDB(config=config)
    sparse_db.load()

    # Generates a hypothetical answer passage for the question; presumably
    # LLMBundle uses it for HyDE-style retrieval — confirm in app.pipeline.
    hyde_chain = LLMChain(
        llm=llm.model,
        prompt=PromptTemplate(
            template="Write a short passage to answer the question: {question}",
            input_variables=["question"],
        ),
    )

    return LLMBundle(
        chain=chain,
        reranker=reranker,
        chunk_sizes=config.embeddings.chunk_sizes,
        sparse_db=sparse_db,
        dense_db=dense_db,
        hyde_chain=hyde_chain,
    )


@main.command("predict")
@click.option(
    "-c",
    "app_config_path",
    required=True,
    type=click.Path(exists=True, dir_okay=False, file_okay=True),
    help="Specifies App JavaScript configuration file (should be module exported)",
)
@click.option(
    "-m",
    "model_config_path",
    required=True,
    type=click.Path(exists=True, dir_okay=False, file_okay=True),
    help="Specifies Model JavaScript configuration file (should be module exported)",
)
@click.option(
    "-n",
    "--limit",
    "limit",
    default=10,
    show_default=True,
    type=click.IntRange(min=1),
    help="Number of questions (from the top of evaluation_dataset.json) to run",
)
def predict_pipeline(app_config_path: str, model_config_path: str, limit: int = 10):
    """Answer the evaluation questions with the RAG pipeline and save the results.

    Reads ``evaluation_dataset.json`` (JSON Lines; uses its ``question``,
    ``answer`` and ``context`` fields), answers the first ``limit`` questions and
    writes question / answer / contexts / ground-truth records to
    ``evaluation_output.json`` (JSON Lines), which the ``evaluate`` command reads.
    """
    config = load_config(app_config_path, model_config_path)

    # Read the questions before loading any models so a missing or malformed
    # dataset fails fast.
    test_dataset = pd.read_json("evaluation_dataset.json", lines=True).head(limit)

    llm_bundle = _build_llm_bundle(config)

    evaluate_data = {
        "question": [],
        "answer": [],
        "contexts": [],  # should be a list[list[str]]
        "ground_truth": [],
        "context_ground_truth": [],
    }

    for _, row in tqdm(test_dataset.iterrows(), total=len(test_dataset)):
        output = llm_bundle.get_and_parse_response(
            query=row["question"],
            config=config,
        )

        evaluate_data["question"].append(row["question"])
        evaluate_data["answer"].append(output.response)
        # assumes output.semantic_search is a list[str] of retrieved chunks — TODO confirm
        evaluate_data["contexts"].append(output.semantic_search)
        evaluate_data["ground_truth"].append(row["answer"])
        evaluate_data["context_ground_truth"].append(row["context"])

    evaluate_dataset = Dataset.from_dict(evaluate_data)
    evaluate_dataset.to_pandas().to_json(
        "evaluation_output.json", orient="records", lines=True
    )
    logger.info("Wrote {} predictions to evaluation_output.json", len(test_dataset))


@main.command("evaluate")
def evaluate_pipeline():
    """Score ``evaluation_output.json`` with ragas and save the per-question results.

    Uses Gemini Pro (Vertex AI) as the judge LLM and the
    ``maidalun1020/bce-embedding-base_v1`` sentence-transformer for embeddings.
    Prints the mean of every numeric column and writes the per-row results to
    ``evaluation_results.csv``.
    """
    judge_llm = LangchainLLMWrapper(ChatVertexAI(model_name="gemini-pro"))
    # Embeddings come from a local sentence-transformer model, not from Vertex AI.
    judge_embeddings = LangchainEmbeddingsWrapper(
        SentenceTransformerEmbeddings(model_name="maidalun1020/bce-embedding-base_v1")
    )

    metrics = [
        # the accuracy of the generated answer when compared to the ground truth
        answer_correctness,
        # evaluates whether all the ground-truth relevant items present in the contexts are ranked higher or not
        context_precision,
        # how pertinent the generated answer is to the given prompt
        answer_relevancy,
    ]

    evaluate_dataset = Dataset.from_pandas(
        pd.read_json("evaluation_output.json", lines=True)
    )

    evaluate_result = evaluate(
        dataset=evaluate_dataset,
        metrics=metrics,
        llm=judge_llm,
        embeddings=judge_embeddings,
        is_async=True,
    )

    evaluate_result_df = evaluate_result.to_pandas()
    # Drop the (list-valued) context columns from the saved results. errors="ignore"
    # keeps this from raising KeyError if the result frame lacks either column.
    evaluate_result_df = evaluate_result_df.drop(
        columns=["contexts", "context_ground_truth"], errors="ignore"
    )
    # Mean of each numeric (metric) column, e.g. answer_correctness,
    # context_precision, answer_relevancy.
    print(evaluate_result_df.mean(numeric_only=True))
    evaluate_result_df.to_csv("evaluation_results.csv", index=False)


if __name__ == "__main__":
    # Script entry point: hand control to the click command group
    # (subcommands: index, predict, evaluate).
    main()