import gradio as gr
from huggingface_hub import InferenceClient
import json
import numpy as np
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
import logging
import time

# Configure the root logger at INFO level so the pipeline's progress messages are shown.
logging.basicConfig(level=logging.INFO)

# For more information on `huggingface_hub` Inference API support, please check the docs:
# https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference

def load_embeddings(file_path):
    """Read and return the JSON document stored at ``file_path``.

    The file is decoded as UTF-8 and the parsed object is returned unchanged.
    Elsewhere in this module it is used as a list of dicts carrying
    ``'embedding'``, ``'name'`` and ``'content'`` keys.
    """
    logging.info(f"Loading embeddings from {file_path}")
    with open(file_path, 'r', encoding='utf-8') as handle:
        records = json.load(handle)
    logging.info(f"Loaded {len(records)} embeddings")
    return records

def _cosine_similarities(query_vec, matrix):
    """Return the cosine similarity between ``query_vec`` (d,) and each row of ``matrix`` (n, d).

    A zero-norm vector has its norm treated as 1, so its similarity comes out
    as 0 instead of NaN (the same convention sklearn's ``cosine_similarity``
    uses via ``normalize``).
    """
    query_norm = np.linalg.norm(query_vec)
    if query_norm == 0:
        query_norm = 1.0
    row_norms = np.linalg.norm(matrix, axis=1)
    row_norms[row_norms == 0] = 1.0
    return (matrix @ query_vec) / (row_norms * query_norm)


# Function to get relevant articles based on user query
def get_relevant_documents(query, embeddings_data, model, top_k=3):
    """Return the ``top_k`` entries of ``embeddings_data`` most similar to ``query``.

    Args:
        query: Text to search for; encoded with ``model.encode``.
        embeddings_data: Sequence of dicts, each with an ``'embedding'`` vector
            of the same length as the encoded query.
        model: Object exposing ``encode(text)`` (e.g. a SentenceTransformer).
        top_k: Maximum number of entries to return (sliced like a list).

    Returns:
        The matching entries (the original dict objects), ordered by
        descending cosine similarity. Entries with equal scores keep their
        order from ``embeddings_data``. An empty corpus yields ``[]``.

    Raises:
        ValueError: If embedding dimensions are inconsistent with the query.
    """
    logging.info(f"Received query: {query}")
    start_time = time.time()

    # Flatten so a (1, d) encode result is handled the same as a (d,) one.
    query_embedding = np.asarray(model.encode(query), dtype=float).ravel()

    if embeddings_data:
        # One matrix-vector product instead of a per-entry sklearn call
        # (which re-validates and re-wraps its inputs on every iteration).
        matrix = np.array([entry['embedding'] for entry in embeddings_data], dtype=float)
        logging.debug(f"Stacked {len(embeddings_data)} embeddings")
        similarities = _cosine_similarities(query_embedding, matrix)
    else:
        similarities = np.empty(0)

    logging.info("Sorting similarities")
    # Stable descending order: ties keep corpus order, like list.sort(reverse=True).
    ranked = np.argsort(-similarities, kind="stable")[:top_k]
    top_entries = [embeddings_data[i] for i in ranked]
    top_scores = [float(similarities[i]) for i in ranked]

    duration = time.time() - start_time

    logging.info(f"Query processed in {duration:.2f} seconds")
    logging.info(f"Top {top_k} documents returned with similarities: {top_scores}")

    return top_entries

def format_documents(documents):
    """Render retrieved documents as one plain-text context string.

    Each document contributes ``"Relevant article: <name>\\n<content>\\n\\n"``;
    the pieces are concatenated in input order (empty input gives ``""``).
    """
    logging.info(f"Formatting {len(documents)} documents")
    sections = (
        f"Relevant article: {doc['name']}\n{doc['content']}\n\n"
        for doc in documents
    )
    return "".join(sections)

# Main chatbot function that integrates RAG
def respond(
    message,
    history: list[tuple[str, str]],
    system_message,
    max_tokens,
    temperature,
    top_p,
    embeddings_data,
    tokenizer=None,
    model=None,
    top_k=3,
):
    """Answer ``message`` with retrieval-augmented generation.

    Retrieves the ``top_k`` documents most similar to ``message``, appends them
    to the system prompt, replays the chat ``history`` and asks the
    ``HuggingFaceH4/zephyr-7b-beta`` chat-completion endpoint for a reply.

    Args:
        message: The new user message.
        history: Previous turns as ``(user, assistant)`` pairs; empty halves
            are skipped.
        system_message: Base system prompt the retrieved context is appended to.
        max_tokens: Forwarded to ``chat_completion``.
        temperature: Forwarded to ``chat_completion``.
        top_p: Forwarded to ``chat_completion``.
        embeddings_data: Corpus entries with ``'embedding'``, ``'name'`` and
            ``'content'`` keys.
        tokenizer: Legacy positional slot. When ``model`` is not given, the
            object passed here is used as the embedding model (callers may
            pass the encoder positionally in this position).
        model: Embedding model exposing ``encode``; preferred over ``tokenizer``.
        top_k: Number of documents to retrieve.

    Returns:
        The complete generated reply as a single string.

    Raises:
        ValueError: If neither ``model`` nor ``tokenizer`` is supplied.
    """
    # The previous body forwarded `tokenizer` as the encoder and `model` as
    # `top_k`; resolve a single encoder explicitly instead.
    encoder = model if model is not None else tokenizer
    if encoder is None:
        raise ValueError("respond() requires an embedding model (pass `model`).")

    logging.info(f"New user query: {message}")

    start_time = time.time()

    # Search for relevant documents based on user input
    relevant_docs = get_relevant_documents(message, embeddings_data, encoder, top_k=top_k)
    retrieved_context = format_documents(relevant_docs)

    # Log the statistics about the retrieved documents
    logging.info(f"Total documents retrieved: {len(relevant_docs)}")
    logging.info("Documents: " + str([doc['name'] for doc in relevant_docs]))

    # Add the retrieved context as part of the system message
    system_message_with_context = system_message + "\n\n" + "Relevant documents:\n" + retrieved_context
    logging.info("System message updated with retrieved context")

    messages = [{"role": "system", "content": system_message_with_context}]

    for val in history or []:
        if val[0]:
            messages.append({"role": "user", "content": val[0]})
        if val[1]:
            messages.append({"role": "assistant", "content": val[1]})

    messages.append({"role": "user", "content": message})
    logging.info("Messages prepared for InferenceClient")

    client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")

    logging.info("Sending request to InferenceClient")

    # Collect the full response instead of yielding each token.
    # `chunk` (not `message`) avoids shadowing the user's message.
    parts = []
    for chunk in client.chat_completion(
        messages,
        max_tokens=max_tokens,
        stream=True,
        temperature=temperature,
        top_p=top_p,
    ):
        if not chunk.choices:
            continue
        token = chunk.choices[0].delta.content
        # delta.content is optional in streamed chunks; `response += None` crashed.
        if token:
            parts.append(token)
    response = "".join(parts)

    total_duration = time.time() - start_time
    logging.info(f"Response generated in {total_duration:.2f} seconds")

    return response  # Return the complete response as a string


# Load the retrieval corpus and the sentence-embedding model once, at import
# time, so every chat request reuses them.
embeddings_file = 'Code Civil vectorised.json'
logging.info("Starting application, loading embeddings and model")
embeddings_data = load_embeddings(embeddings_file)
embedding_model = SentenceTransformer(
    'Lajavaness/bilingual-embedding-small',
    trust_remote_code=True,
)
logging.info("Model and embeddings loaded successfully")


def _chat(message, history, system_message, max_tokens, temperature, top_p):
    """Gradio callback: forward the UI inputs plus the preloaded corpus and model to ``respond``."""
    return respond(
        message, history, system_message, max_tokens, temperature, top_p,
        embeddings_data, embedding_model,
    )


# Extra controls rendered below the chat box; their values arrive after
# `history` in the callback's arguments, in this order.
_additional_inputs = [
    gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
    gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
    gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
    gr.Slider(
        minimum=0.1,
        maximum=1.0,
        value=0.95,
        step=0.05,
        label="Top-p (nucleus sampling)",
    ),
]

# Gradio interface
demo = gr.ChatInterface(_chat, additional_inputs=_additional_inputs)

if __name__ == "__main__":
    logging.info("Launching Gradio app")
    demo.launch()