import os
import time
from typing import List
from qdrant_client import QdrantClient, models
from langchain_core.documents import Document
from semantic_cache.main import SemanticCache
from langchain_community.embeddings.huggingface import HuggingFaceEmbeddings
from Router.router import Evaluator
from langchain_openai import ChatOpenAI
# from utils.pipelines.main import get_last_user_message, add_or_update_system_message, pop_system_message
from blueprints.rag_utils import format_docs, translate
from blueprints.prompts import QUERY_PROMPT, evaluator_intent, basic_template, chitchat_prompt, safe_prompt, cache_prompt
from SafetyChecker import SafetyChecker
from langchain.retrievers import EnsembleRetriever
from BM25 import BM25SRetriever
# from database_Routing import DB_Router
from langchain.retrievers.multi_query import MultiQueryRetriever
# import cohere
from langchain.retrievers.document_compressors import LLMChainExtractor
from langchain_core.output_parsers import StrOutputParser
from langchain_core.output_parsers import BaseOutputParser
# from langchain_cohere import CohereRerank
from langchain.retrievers.contextual_compression import ContextualCompressionRetriever
from langchain_groq import ChatGroq
from langchain_core.runnables import RunnablePassthrough
import time 
from qdrant_client import QdrantClient
from langchain_community.vectorstores import Qdrant
from langchain.retrievers.document_compressors import LLMChainFilter
from langchain.retrievers.document_compressors import EmbeddingsFilter
from langchain.retrievers.document_compressors import LLMListwiseRerank
from dotenv import load_dotenv
from langchain_openai import OpenAIEmbeddings
load_dotenv()

# The project stores its OpenAI key as OPENAI_KEY; mirror it to OPENAI_API_KEY,
# the variable name the OpenAI client libraries read. Only assign when present:
# assigning None into os.environ raises TypeError and would abort the import.
_openai_key = os.getenv('OPENAI_KEY')
if _openai_key is not None:
    os.environ["OPENAI_API_KEY"] = _openai_key

# Embedding model shared by every Qdrant collection in this file. The name is
# historical (a HuggingFace model was used before); it is now OpenAI
# text-embedding-3-small.
HF_EMBEDDING = OpenAIEmbeddings(model='text-embedding-3-small', api_key=_openai_key)
class LineListOutputParser(BaseOutputParser[List[str]]):
    """LangChain output parser that turns raw LLM text into a list of lines."""

    def parse(self, text: str) -> List[str]:
        """Trim *text*, split it on newline characters and drop empty entries.

        Only zero-length entries are removed; whitespace-only lines are kept.
        """
        trimmed = text.strip()
        return [line for line in trimmed.split("\n") if line]



def add_or_update_system_message(content: str, messages: List[dict]):
    """Append *content* to the leading system message, or insert a new one.

    :param content: Text to add to the system prompt.
    :param messages: Chat message dicts with ``role``/``content`` keys. The list
        is modified in place.
    :return: The same ``messages`` list, for convenience.
    """
    if messages and messages[0].get("role") == "system":
        existing = messages[0].get("content") or ""
        # A freshly inserted system message has no trailing newline, so add a
        # separator when needed; otherwise the two texts would run together.
        separator = "" if not existing or existing.endswith("\n") else "\n"
        messages[0]["content"] = f"{existing}{separator}{content}\n"
    else:
        # No system message yet: put one at the front of the conversation.
        messages.insert(0, {"role": "system", "content": content})
    return messages

def split_context(context):
    """Rebuild a system prompt from a rendered RAG prompt string.

    Keeps the (stripped) text before the first ``"User question"`` marker and
    appends, after a newline, everything from the ``"<context>"`` marker onward
    in the remaining text; the user-question line itself is dropped.

    :param context: Full prompt text.
    :return: ``"<system part>\\n<context part>"``. If ``"User question"`` is
        absent the whole stripped text is the system part; if ``"<context>"``
        is absent the context part is empty.
    """
    # str.find returns -1 for a missing marker; check it explicitly instead of
    # using it as a slice bound, which would cut off the last character.
    split_index = context.find("User question")
    if split_index == -1:
        system_prompt, user_question = context.strip(), ""
    else:
        system_prompt = context[:split_index].strip()
        user_question = context[split_index:].strip()

    context_index = user_question.find("<context>")
    context_part = user_question[context_index:] if context_index != -1 else ""
    return f"{system_prompt}\n{context_part}"

def extract_metadata(docs, headers=('Header_1', 'Header_2', 'Header_3')):
    """Collect the section-header values stored in each document's metadata.

    :param docs: Iterable of objects exposing a ``metadata`` dict.
    :param headers: Metadata keys to read, in order.
    :return: One list per document with the values of those keys that are
        present and truthy, in ``headers`` order. Missing or empty keys are
        skipped, so a value's position in the list does not always identify
        its header level (e.g. Header_1 + Header_3 yields a 2-element list).
    """
    return [
        [doc.metadata[key] for key in headers if doc.metadata.get(key)]
        for doc in docs
    ]


def search_with_filter(query, vector_store, k, headers):
    """Similarity-search *vector_store*, restricted to chunks with the given headers.

    :param query: Text to search for.
    :param vector_store: Object exposing
        ``similarity_search(query=..., k=..., filter=...)`` that accepts a
        Qdrant ``models.Filter`` (the ``Qdrant`` stores built in this script).
    :param k: Maximum number of documents to return.
    :param headers: Sequence of header values matched positionally:
        ``headers[0]`` must equal ``metadata.Header_1``, ``headers[1]``
        ``metadata.Header_2``, and so on; all conditions must hold. An empty
        sequence produces ``Filter(must=[])``.
    :return: Whatever ``vector_store.similarity_search`` returns.

    NOTE(review): ``extract_metadata`` drops missing/empty headers, so a chunk
    with Header_1 and Header_3 but no Header_2 yields a 2-element list whose
    second value is matched against Header_2 here — confirm this is acceptable.
    """
    # Build one exact-match condition per header level (Header_N is 1-based).
    # A loop replaces the old per-length if/elif ladder, which silently fell
    # back to an unfiltered search for more than three headers.
    conditions = [
        models.FieldCondition(
            key=f"metadata.Header_{level}",
            match=models.MatchValue(value=value),
        )
        for level, value in enumerate(headers, start=1)
    ]

    # Run the query with all conditions combined under `must`.
    return vector_store.similarity_search(
        query=query,
        k=k,
        filter=models.Filter(must=conditions),
    )

def get_relevant_documents(documents: List["Document"], limit: int) -> List["Document"]:
    """Return up to *limit* documents with distinct ``page_content``.

    Input order is preserved and the first occurrence of each text wins.

    :param documents: Iterable of objects exposing ``page_content``.
    :param limit: Maximum number of documents to return; ``<= 0`` yields ``[]``.
    :return: The de-duplicated prefix, as a new list.
    """
    # Guard up front: a post-append `== limit` check alone never fires for
    # limit <= 0 and would return every unique document.
    if limit <= 0:
        return []
    unique_docs = []
    seen_texts = set()
    for doc in documents:
        if doc.page_content in seen_texts:
            continue
        seen_texts.add(doc.page_content)
        unique_docs.append(doc)
        # Stop as soon as the limit is reached so no extra input is consumed.
        if len(unique_docs) == limit:
            break
    return unique_docs




def _load_documents(path):
    """Unpickle and return the object stored at *path*.

    NOTE(review): ``pickle.load`` can execute arbitrary code; only use it on
    trusted, locally generated files such as the ``data/*.pkl`` dumps loaded
    in ``main``.
    """
    import pickle
    with open(path, 'rb') as f:
        return pickle.load(f)


def main():
    """Run the interactive question-answering loop on stdin/stdout.

    For each question:
    1. Run the safety checker; unsafe input gets a refusal and ends the session.
    2. On a semantic-cache hit, answer from the cached content and move on.
    3. Otherwise classify the intent. OUT_OF_SCOPE gets a chit-chat answer.
       The ASK_* intents go through the pipeline below.
    4. Run hybrid BM25 + vector retrieval on the matching collection and
       rerank the hits with an LLM.
    5. Re-search the collection within each reranked chunk's section headers,
       then generate the final answer from the de-duplicated context.

    The loop ends on EOF or Ctrl-C.
    """
    client = QdrantClient(url="http://localhost:6333")

    stsv = Qdrant(client, collection_name="sotaysinhvien_filter", embeddings=HF_EMBEDDING)
    stsv_db = stsv.as_retriever(search_kwargs={'k': 10})
    gthv = Qdrant(client, collection_name="gioithieuhocvien_filter", embeddings=HF_EMBEDDING)
    gthv_db = gthv.as_retriever(search_kwargs={'k': 10})
    ttts = Qdrant(client, collection_name="thongtintuyensinh_filter", embeddings=HF_EMBEDDING)
    ttts_db = ttts.as_retriever(search_kwargs={'k': 10})

    sotaysinhvien = _load_documents('data/sotaysinhvien_filter.pkl')
    thongtintuyensinh = _load_documents('data/thongtintuyensinh_filter.pkl')
    gioithieuhocvien = _load_documents('data/gioithieuhocvien_filter.pkl')

    retriever_bm25_tuyensinh = BM25SRetriever.from_documents(
        thongtintuyensinh, k=10, save_directory="data/bm25s/ttts")
    retriever_bm25_sotay = BM25SRetriever.from_documents(
        sotaysinhvien, k=10, save_directory="data/bm25s/stsv")
    retriever_bm25_hocvien = BM25SRetriever.from_documents(
        gioithieuhocvien, k=10, save_directory="data/bm25s/gthv")

    # Intent label -> (log line, hybrid retriever, vector store used for the
    # header-filtered re-search). BM25 and vector results are weighted equally.
    routes = {
        'ASK_QUYDINH': (
            'SO TAY SINH VIEN DB',
            EnsembleRetriever(retrievers=[retriever_bm25_sotay, stsv_db], weights=[0.5, 0.5]),
            stsv,
        ),
        'ASK_HOCVIEN': (
            'GIOI THIEU HOC VIEN DB',
            EnsembleRetriever(retrievers=[retriever_bm25_hocvien, gthv_db], weights=[0.5, 0.5]),
            gthv,
        ),
        'ASK_TUYENSINH': (
            'THONG TIN TUYEN SINH DB',
            EnsembleRetriever(retrievers=[retriever_bm25_tuyensinh, ttts_db], weights=[0.5, 0.5]),
            ttts,
        ),
    }

    # `llm` (low temperature) reranks; `llm2` writes the user-facing answers.
    llm = ChatGroq(model_name="llama3-70b-8192", temperature=0.1, api_key=os.getenv('llm_api_3'))
    llm2 = ChatGroq(model_name="llama-3.1-70b-versatile", temperature=1, api_key=os.getenv('llm_api_8'))

    # Loop-invariant helpers, built once instead of on every question.
    # NOTE(review): nothing visible here inserts answers into the cache;
    # confirm SemanticCache is populated elsewhere.
    cache = SemanticCache()
    checker = SafetyChecker()
    evaluator = Evaluator(llm="llama3-70b", prompt=evaluator_intent)
    reranker = LLMListwiseRerank.from_llm(llm=llm, top_n=5)
    another_chain = chitchat_prompt | llm2 | StrOutputParser()
    safe_chain = safe_prompt | llm2 | StrOutputParser()
    cache_chain = cache_prompt | llm2 | StrOutputParser()
    best_chain = basic_template | llm2 | StrOutputParser()

    while True:
        try:
            user_message = input("Nhập câu hỏi nào!: ")
        except (EOFError, KeyboardInterrupt):
            break

        safety_result = checker.check_safety(translate(user_message))
        print("Safety check :", safety_result)
        if safety_result != 'safe':
            print("UNSAFE")
            response = safe_chain.invoke({'meaning': f'{safety_result}'})
            print(response)
            # End the session on unsafe input (as before) without relying on
            # the site-module `exit()` builtin.
            raise SystemExit

        cache_result = cache.checker(user_message)
        if cache_result is not None:
            print("###Cache hit!###")
            response = cache_chain.invoke({"question": f'{user_message}', "content": f"{cache_result}"})
            print(response)
            # Answered from cache: skip classification and the RAG pipeline.
            continue

        output = evaluator.classify_text(user_message)
        print(output.result)

        if output.result == 'OUT_OF_SCOPE':
            print('OUT OF SCOPE')
            response = another_chain.invoke({"question": f"{user_message}"})
            print(response)
            continue

        route = routes.get(output.result)
        if route is None:
            print('Retriever is not defined. Check output results and ensure retriever is assigned correctly.')
            continue

        label, ensemble_retriever, source = route
        print(label)

        # Timing covers hybrid retrieval + LLM rerank only.
        start_time = time.time()
        tailieu = ensemble_retriever.invoke(f"{user_message}")
        docs = reranker.compress_documents(tailieu, user_message)
        end_time = time.time()

        # Expand the reranked hits: re-search the same collection restricted to
        # each hit's section headers to pull in more related chunks.
        full_result = []
        for meta_data_doc in extract_metadata(docs):
            full_result.extend(search_with_filter(user_message, source, 10, meta_data_doc))
        print("Context liên quan" + '\n')
        print(full_result)

        result_final = get_relevant_documents(full_result, 10)
        context = format_docs(result_final)
        best_result = best_chain.invoke({"question": f'{user_message}', "context": f"{context}"})
        print(f'Câu trả lời tối ưu nhất: {best_result}')
        print(f'TIME USING : {end_time - start_time}')


if __name__ == "__main__":
    main()