import copy
import os
from pathlib import Path
from typing import Union, Any, Optional, List, Dict, Tuple, ClassVar, Collection

import tiktoken
from langchain.chains import create_extraction_chain
from langchain.chains.question_answering import load_qa_chain, stuff_prompt, refine_prompts, map_reduce_prompt, \
    map_rerank_prompt
from langchain.prompts import SystemMessagePromptTemplate, HumanMessagePromptTemplate, ChatPromptTemplate
from langchain.retrievers import MultiQueryRetriever
from langchain.schema import Document
from langchain_community.vectorstores.chroma import Chroma, DEFAULT_K
from langchain_community.vectorstores.faiss import FAISS
from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.utils import xor_args
from langchain_core.vectorstores import VectorStore, VectorStoreRetriever
from tqdm import tqdm

from document_qa.grobid_processors import GrobidProcessor


def _results_to_docs_scores_and_embeddings(results: Any) -> List[Tuple[Document, float, List[float]]]:
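    """Zip a raw Chroma query result into (Document, distance, embedding) tuples.

    Assumes the dict shape returned by a Chroma ``collection.query`` call with
    ``include=['metadatas', 'documents', 'embeddings', 'distances']`` for a
    single query text (hence the ``[0]`` indexing on each field).
    """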
    return [
        (Document(page_content=result[0], metadata=result[1] or {}), result[2], result[3])
        for result in zip(
            results["documents"][0],
            results["metadatas"][0],
            results["distances"][0],
            results["embeddings"][0],
        )
    ]


class TextMerger:
    """
    This class tries to replicate the RecursiveCharacterTextSplitter from LangChain, while preserving
    and merging the coordinate information from the PDF document.
    """

    def __init__(self, model_name=None, encoding_name="gpt2"):
        if model_name is not None:
            self.enc = tiktoken.encoding_for_model(model_name)
        else:
            self.enc = tiktoken.get_encoding(encoding_name)

    def encode(self, text, allowed_special=set(), disallowed_special="all"):
        return self.enc.encode(
            text,
            allowed_special=allowed_special,
            disallowed_special=disallowed_special,
        )

    def merge_passages(self, passages, chunk_size, tolerance=0.2):
        new_passages = []
        new_coordinates = []
        current_texts = []
        current_coordinates = []
        for idx, passage in enumerate(passages):
            text = passage['text']
            coordinates = passage['coordinates']
            current_texts.append(text)
            current_coordinates.append(coordinates)

            accumulated_text = " ".join(current_texts)

            encoded_accumulated_text = self.encode(accumulated_text)

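            # If the accumulated chunk overflows chunk_size by more than the
            # tolerance, close the chunk before the current passage (or, if a
            # single passage is already too large, emit it on its own).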
            if len(encoded_accumulated_text) > chunk_size + chunk_size * tolerance:
                if len(current_texts) > 1:
                    new_passages.append(current_texts[:-1])
                    new_coordinates.append(current_coordinates[:-1])
                    current_texts = [current_texts[-1]]
                    current_coordinates = [current_coordinates[-1]]
                else:
                    new_passages.append(current_texts)
                    new_coordinates.append(current_coordinates)
                    current_texts = []
                    current_coordinates = []

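            # The accumulated chunk landed within the tolerance window: close it as-is.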
            elif chunk_size <= len(encoded_accumulated_text) < chunk_size + chunk_size * tolerance:
                new_passages.append(current_texts)
                new_coordinates.append(current_coordinates)
                current_texts = []
                current_coordinates = []

        if len(current_texts) > 0:
            new_passages.append(current_texts)
            new_coordinates.append(current_coordinates)

        new_passages_struct = []
        for i, passages in enumerate(new_passages):
            text = " ".join(passages)
            coordinates = ";".join(new_coordinates[i])

            new_passages_struct.append(
                {
                    "text": text,
                    "coordinates": coordinates,
                    "type": "aggregated chunks",
                    "section": "mixed",
                    "subSection": "mixed"
                }
            )

        return new_passages_struct
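
    # Usage sketch (illustrative; real passages and their "coordinates" strings
    # come from GrobidProcessor):
    #
    #   merger = TextMerger()
    #   merged = merger.merge_passages(
    #       [
    #           {"text": "First paragraph.", "coordinates": "1,56,72,120,10"},
    #           {"text": "Second paragraph.", "coordinates": "1,56,90,120,10"},
    #       ],
    #       chunk_size=256,
    #   )
    #   # -> a single aggregated chunk whose "coordinates" joins both boxes with ";"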


class BaseRetrieval:

    def __init__(
            self,
            persist_directory: Path,
            embedding_function
    ):
        self.embedding_function = embedding_function
        self.persist_directory = persist_directory


class AdvancedVectorStoreRetriever(VectorStoreRetriever):
    allowed_search_types: ClassVar[Collection[str]] = (
        "similarity",
        "similarity_score_threshold",
        "mmr",
        "similarity_with_embeddings"
    )

    def _get_relevant_documents(
            self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        if self.search_type == "similarity":
            docs = self.vectorstore.similarity_search(query, **self.search_kwargs)
        elif self.search_type == "similarity_score_threshold":
            docs_and_similarities = (
                self.vectorstore.similarity_search_with_relevance_scores(
                    query, **self.search_kwargs
                )
            )
            for doc, similarity in docs_and_similarities:
                if '__similarity' not in doc.metadata.keys():
                    doc.metadata['__similarity'] = similarity

            docs = [doc for doc, _ in docs_and_similarities]
        elif self.search_type == "mmr":
            docs = self.vectorstore.max_marginal_relevance_search(
                query, **self.search_kwargs
            )
        elif self.search_type == "similarity_with_embeddings":
            docs_scores_and_embeddings = (
                self.vectorstore.advanced_similarity_search(
                    query, **self.search_kwargs
                )
            )

            for doc, score, embeddings in docs_scores_and_embeddings:
                if '__embeddings' not in doc.metadata.keys():
                    doc.metadata['__embeddings'] = embeddings
                if '__similarity' not in doc.metadata.keys():
                    doc.metadata['__similarity'] = score

            docs = [doc for doc, _, _ in docs_scores_and_embeddings]
        else:
            raise ValueError(f"search_type of {self.search_type} not allowed.")
        return docs


class AdvancedVectorStore(VectorStore):
    def as_retriever(self, **kwargs: Any) -> AdvancedVectorStoreRetriever:
        tags = kwargs.pop("tags", None) or []
        tags.extend(self._get_retriever_tags())
        return AdvancedVectorStoreRetriever(vectorstore=self, **kwargs, tags=tags)
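
    # Usage sketch: the "similarity_with_embeddings" search type attaches the
    # similarity score and the embedding vector to each returned document's
    # metadata (the store variable and the query string are hypothetical):
    #
    #   retriever = store.as_retriever(
    #       search_type="similarity_with_embeddings",
    #       search_kwargs={"k": 4},
    #   )
    #   docs = retriever.get_relevant_documents("thermal conductivity of the sample")
    #   docs[0].metadata["__similarity"], docs[0].metadata["__embeddings"]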


class ChromaAdvancedRetrieval(Chroma, AdvancedVectorStore):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    @xor_args(("query_texts", "query_embeddings"))
    def __query_collection(
            self,
            query_texts: Optional[List[str]] = None,
            query_embeddings: Optional[List[List[float]]] = None,
            n_results: int = 4,
            where: Optional[Dict[str, str]] = None,
            where_document: Optional[Dict[str, str]] = None,
            **kwargs: Any,
    ) -> List[Document]:
        """Query the chroma collection."""
        try:
            import chromadb  # noqa: F401
        except ImportError:
            raise ValueError(
                "Could not import chromadb python package. "
                "Please install it with `pip install chromadb`."
            )
        return self._collection.query(
            query_texts=query_texts,
            query_embeddings=query_embeddings,
            n_results=n_results,
            where=where,
            where_document=where_document,
            **kwargs,
        )

    def advanced_similarity_search(
            self,
            query: str,
            k: int = DEFAULT_K,
            filter: Optional[Dict[str, str]] = None,
            **kwargs: Any,
    ) -> List[Tuple[Document, float, List[float]]]:
        docs_scores_and_embeddings = self.similarity_search_with_scores_and_embeddings(query, k, filter=filter)
        return docs_scores_and_embeddings

    def similarity_search_with_scores_and_embeddings(
            self,
            query: str,
            k: int = DEFAULT_K,
            filter: Optional[Dict[str, str]] = None,
            where_document: Optional[Dict[str, str]] = None,
            **kwargs: Any,
    ) -> List[Tuple[Document, float, List[float]]]:

        if self._embedding_function is None:
            results = self.__query_collection(
                query_texts=[query],
                n_results=k,
                where=filter,
                where_document=where_document,
                include=['metadatas', 'documents', 'embeddings', 'distances']
            )
        else:
            query_embedding = self._embedding_function.embed_query(query)
            results = self.__query_collection(
                query_embeddings=[query_embedding],
                n_results=k,
                where=filter,
                where_document=where_document,
                include=['metadatas', 'documents', 'embeddings', 'distances']
            )

        return _results_to_docs_scores_and_embeddings(results)
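
    # Sketch of a direct call (the "where" filter follows Chroma's metadata
    # filter syntax; the store variable and the field name are hypothetical):
    #
    #   doc, distance, embedding = store.similarity_search_with_scores_and_embeddings(
    #       "solar cell efficiency", k=4, filter={"section": "results"},
    #   )[0]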


class FAISSAdvancedRetrieval(FAISS):
    pass


class NER_Retrival(VectorStore):
    """
    This class implements retrieval based on NER models.
    It is an alternative to embedding-based retrieval that relies on extracted entities.
    """
    pass


engines = {
    'chroma': ChromaAdvancedRetrieval,
    'faiss': FAISSAdvancedRetrieval,
    'ner': NER_Retrival
}


class DataStorage:

    def __init__(
            self,
            embedding_function,
            root_path: Path = None,
            engine=ChromaAdvancedRetrieval,
    ) -> None:
        # Keep these maps per instance: class-level mutable attributes would be
        # shared across all DataStorage instances.
        self.embeddings_dict = {}
        self.embeddings_map_from_md5 = {}
        self.embeddings_map_to_md5 = {}
        self.root_path = root_path
        self.engine = engine
        self.embedding_function = embedding_function

        if root_path is not None:
            self.embeddings_root_path = root_path
            if not os.path.exists(root_path):
                os.makedirs(root_path)
            else:
                self.load_embeddings(self.embeddings_root_path)

    def load_embeddings(self, embeddings_root_path: Union[str, Path]) -> None:
        """
        Load the vector storages, assuming they are all persisted under a single root directory,
        with one data store per document in each subdirectory.
        """

        embeddings_directories = [f for f in os.scandir(embeddings_root_path) if f.is_dir()]

        if len(embeddings_directories) == 0:
            print("No available embeddings")
            return

        for embedding_document_dir in embeddings_directories:
            self.embeddings_dict[embedding_document_dir.name] = self.engine(
                persist_directory=embedding_document_dir.path,
                embedding_function=self.embedding_function
            )

            filename_list = list(Path(embedding_document_dir).glob('*.storage_filename'))
            if filename_list:
                filename = filename_list[0].name.replace(".storage_filename", "")
                self.embeddings_map_from_md5[embedding_document_dir.name] = filename
                self.embeddings_map_to_md5[filename] = embedding_document_dir.name

        print("Embedding loaded: ", len(self.embeddings_dict.keys()))

    def get_loaded_embeddings_ids(self):
        return list(self.embeddings_dict.keys())

    def get_md5_from_filename(self, filename):
        return self.embeddings_map_to_md5[filename]

    def get_filename_from_md5(self, md5):
        return self.embeddings_map_from_md5[md5]

    def embed_document(self, doc_id, texts, metadatas):
        if doc_id not in self.embeddings_dict.keys():
            self.embeddings_dict[doc_id] = self.engine.from_texts(texts,
                                                                  embedding=self.embedding_function,
                                                                  metadatas=metadatas,
                                                                  collection_name=doc_id)
        else:
            # Workaround Chroma (?) breaking change
            self.embeddings_dict[doc_id].delete_collection()
            self.embeddings_dict[doc_id] = self.engine.from_texts(texts,
                                                                  embedding=self.embedding_function,
                                                                  metadatas=metadatas,
                                                                  collection_name=doc_id)

        self.embeddings_root_path = None
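
    # Usage sketch (the embedding model is hypothetical; any LangChain-compatible
    # embedding function works):
    #
    #   from langchain_community.embeddings import HuggingFaceEmbeddings
    #
    #   storage = DataStorage(
    #       embedding_function=HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2"),
    #       root_path=Path("./embeddings"),
    #   )
    #   storage.embed_document("<md5>", ["passage one", "passage two"],
    #                          metadatas=[{"section": "intro"}, {"section": "body"}])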


class DocumentQAEngine:
    llm = None
    qa_chain_type = None

    default_prompts = {
        'stuff': stuff_prompt,
        'refine': refine_prompts,
        "map_reduce": map_reduce_prompt,
        "map_rerank": map_rerank_prompt
    }

    def __init__(self,
                 llm,
                 data_storage: DataStorage,
                 qa_chain_type="stuff",
                 grobid_url=None,
                 memory=None
                 ):

        self.llm = llm
        self.memory = memory
        self.chain = load_qa_chain(llm, chain_type=qa_chain_type)
        self.text_merger = TextMerger()
        self.data_storage = data_storage

        if grobid_url:
            self.grobid_processor = GrobidProcessor(grobid_url)

    def query_document(
            self,
            query: str,
            doc_id,
            output_parser=None,
            context_size=4,
            extraction_schema=None,
            verbose=False
    ) -> (Any, str, list):
        # self.load_embeddings(self.embeddings_root_path)

        if verbose:
            print(query)

        response, coordinates = self._run_query(doc_id, query, context_size=context_size)
        response = response['output_text'] if 'output_text' in response else response

        if verbose:
            print(doc_id, "->", response)

        if output_parser:
            try:
                return self._parse_json(response, output_parser), response, coordinates
            except Exception as oe:
                print("Failed to parse the response", oe)
                return None, response, coordinates
        elif extraction_schema:
            try:
                chain = create_extraction_chain(extraction_schema, self.llm)
                parsed = chain.run(response)
                return parsed, response, coordinates
            except Exception as oe:
                print("Failed to parse the response", oe)
                return None, response, coordinates
        else:
            return None, response, coordinates

    def query_storage(self, query: str, doc_id, context_size=4) -> (List[str], list):
        """
        Returns the context related to a given query
        """
        documents, coordinates = self._get_context(doc_id, query, context_size)

        context_as_text = [doc.page_content for doc in documents]
        return context_as_text, coordinates

    def query_storage_and_embeddings(self, query: str, doc_id, context_size=4):
        """
        Returns the context for a given query; the retriever also stores the embeddings
        and similarity scores in each document's metadata
        """
        db = self.data_storage.embeddings_dict[doc_id]
        retriever = db.as_retriever(search_kwargs={"k": context_size}, search_type="similarity_with_embeddings")
        relevant_documents = retriever.get_relevant_documents(query)

        context_as_text = [doc.page_content for doc in relevant_documents]
        return context_as_text


    def _parse_json(self, response, output_parser):
        system_message = "You are an useful assistant expert in materials science, physics, and chemistry " \
                         "that can process text and transform it to JSON."
        human_message = """Transform the text between three double quotes in JSON.\n\n\n\n
        {format_instructions}\n\nText: \"\"\"{text}\"\"\""""

        system_message_prompt = SystemMessagePromptTemplate.from_template(system_message)
        human_message_prompt = HumanMessagePromptTemplate.from_template(human_message)

        prompt_template = ChatPromptTemplate.from_messages([system_message_prompt, human_message_prompt])

        results = self.llm(
            prompt_template.format_prompt(
                text=response,
                format_instructions=output_parser.get_format_instructions()
            ).to_messages()
        )
        parsed_output = output_parser.parse(results.content)

        return parsed_output

    def _run_query(self, doc_id, query, context_size=4) -> (Any, list):
        # _get_context already returns the documents together with their coordinates
        relevant_documents, relevant_document_coordinates = self._get_context(doc_id, query, context_size)
        response = self.chain.run(input_documents=relevant_documents,
                                  question=query)

        if self.memory:
            self.memory.save_context({"input": query}, {"output": response})
        return response, relevant_document_coordinates

    def _get_context(self, doc_id, query, context_size=4) -> (List[Document], list):
        db = self.data_storage.embeddings_dict[doc_id]
        retriever = db.as_retriever(search_kwargs={"k": context_size})
        relevant_documents = retriever.get_relevant_documents(query)
        relevant_document_coordinates = [
            doc.metadata['coordinates'].split(";") if 'coordinates' in doc.metadata else []
            for doc in relevant_documents
        ]
        if self.memory and len(self.memory.buffer_as_messages) > 0:
            relevant_documents.append(
                Document(
                    page_content="""The following are the previous questions and answers. Use this information only when the question contains unresolved references:\n{}\n\n""".format(
                        self.memory.buffer_as_str))
            )
        return relevant_documents, relevant_document_coordinates

    def get_full_context_by_document(self, doc_id):
        """
        Return the full context from the document
        """
        db = self.data_storage.embeddings_dict[doc_id]
        docs = db.get()
        return docs['documents']

    def _get_context_multiquery(self, doc_id, query, context_size=4):
        db = self.data_storage.embeddings_dict[doc_id].as_retriever(search_kwargs={"k": context_size})
        multi_query_retriever = MultiQueryRetriever.from_llm(retriever=db, llm=self.llm)
        relevant_documents = multi_query_retriever.get_relevant_documents(query)
        return relevant_documents

    def get_text_from_document(self, pdf_file_path, chunk_size=-1, perc_overlap=0.1, verbose=False):
        """
        Extract text from documents using Grobid.
        - if chunk_size is < 0, keeps each paragraph separately
        - if chunk_size > 0, aggregate all paragraphs and split them again using an approximate chunk size
        """
        if verbose:
            print("File", pdf_file_path)
        filename = Path(pdf_file_path).stem
        coordinates = True  # if chunk_size == -1 else False
        structure = self.grobid_processor.process_structure(pdf_file_path, coordinates=coordinates)

        biblio = structure['biblio']
        biblio['filename'] = filename.replace(" ", "_")

        if verbose:
            print("Generating embeddings for file:", filename)

        texts = []
        metadatas = []
        ids = []

        if chunk_size > 0:
            new_passages = self.text_merger.merge_passages(structure['passages'], chunk_size=chunk_size)
        else:
            new_passages = structure['passages']

        for passage in new_passages:
            biblio_copy = copy.copy(biblio)
            if len(str.strip(passage['text'])) > 0:
                texts.append(passage['text'])

                biblio_copy['type'] = passage['type']
                biblio_copy['section'] = passage['section']
                biblio_copy['subSection'] = passage['subSection']
                biblio_copy['coordinates'] = passage['coordinates']
                metadatas.append(biblio_copy)

                # ids.append(passage['passage_id'])

        # Assign one sequential id per kept passage, aligned with texts and metadatas
        ids = list(range(len(texts)))

        return texts, metadatas, ids
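
    # Sketch of the returned structure (the path is hypothetical):
    #
    #   texts, metadatas, ids = engine.get_text_from_document("paper.pdf", chunk_size=500)
    #   # texts: List[str] of passages; metadatas: biblio + section/coordinates
    #   # per passage; ids: sequential ints aligned with texts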

    def create_memory_embeddings(
            self,
            pdf_path,
            doc_id=None,
            chunk_size=500,
            perc_overlap=0.1
    ):
        texts, metadata, ids = self.get_text_from_document(
            pdf_path,
            chunk_size=chunk_size,
            perc_overlap=perc_overlap)
        if doc_id:
            doc_hash = doc_id
        else:
            doc_hash = metadata[0]['hash']

        self.data_storage.embed_document(doc_hash, texts, metadata)

        return doc_hash
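
    # End-to-end sketch (llm, embeddings, and the Grobid URL are assumptions):
    #
    #   storage = DataStorage(embedding_function=embeddings)
    #   engine = DocumentQAEngine(llm, storage, grobid_url="http://localhost:8070")
    #   doc_id = engine.create_memory_embeddings("paper.pdf", chunk_size=500)
    #   _, answer, coordinates = engine.query_document("What is the main finding?", doc_id)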

    def create_embeddings(
            self,
            pdfs_dir_path: Path,
            chunk_size=500,
            perc_overlap=0.1,
            include_biblio=False
    ):
        input_files = []
        for root, dirs, files in os.walk(pdfs_dir_path, followlinks=False):
            for file_ in files:
                if not (file_.lower().endswith(".pdf")):
                    continue
                input_files.append(os.path.join(root, file_))

        for input_file in tqdm(input_files, total=len(input_files), unit='document',
                               desc="Grobid + embeddings processing"):

            md5 = self.calculate_md5(input_file)
            data_path = os.path.join(self.data_storage.embeddings_root_path, md5)

            if os.path.exists(data_path):
                print(data_path, "exists. Skipping it.")
                continue
            # include = ["biblio"] if include_biblio else []
            texts, metadata, ids = self.get_text_from_document(
                input_file,
                chunk_size=chunk_size,
                perc_overlap=perc_overlap)
            filename = metadata[0]['filename']

            vector_db_document = Chroma.from_texts(texts,
                                                   metadatas=metadata,
                                                   embedding=self.data_storage.embedding_function,
                                                   persist_directory=data_path)
            vector_db_document.persist()

            with open(os.path.join(data_path, filename + ".storage_filename"), 'w') as fo:
                fo.write("")

    @staticmethod
    def calculate_md5(input_file: Union[Path, str]):
        import hashlib
        md5_hash = hashlib.md5()
        with open(input_file, 'rb') as fi:
            md5_hash.update(fi.read())
        return md5_hash.hexdigest().upper()