File size: 5,261 Bytes
d92b8b1
 
d8c9651
 
 
 
cd3d014
 
 
 
 
d92b8b1
 
 
 
 
d8c9651
 
cd3d014
d8c9651
cd3d014
 
 
d92b8b1
d8c9651
 
cd3d014
 
 
d8c9651
 
 
cd3d014
d8c9651
 
 
cd3d014
 
d8c9651
cd3d014
d8c9651
 
 
cd3d014
 
 
 
 
 
d8c9651
 
 
 
cd3d014
d8c9651
 
 
 
 
 
d92b8b1
 
 
 
 
 
 
d8c9651
bb1b69d
d8c9651
d92b8b1
cd3d014
 
 
 
d8c9651
bb1b69d
d8c9651
 
cd3d014
 
bb1b69d
cd3d014
d8c9651
 
cd3d014
d8c9651
 
d92b8b1
 
 
 
 
 
 
 
cd3d014
d92b8b1
d8c9651
 
cd3d014
bb1b69d
 
 
d92b8b1
 
 
 
 
 
 
 
 
cd3d014
 
 
 
bb1b69d
 
 
d92b8b1
d8c9651
 
cd3d014
d8c9651
b46f895
cd3d014
d8c9651
 
d92b8b1
d8c9651
 
 
d92b8b1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cd3d014
d8c9651
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
import gradio as gr
from huggingface_hub import InferenceClient
import json
import numpy as np
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
import logging
import time

# Configure the root logger so the module's logging.info(...) progress
# messages are emitted (basicConfig's default handler writes to stderr).
logging.basicConfig(format=logging.BASIC_FORMAT, level=logging.INFO)

# For more information on `huggingface_hub` Inference API support, please check
# the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference

# Load embeddings from a JSON file
def load_embeddings(file_path):
    logging.info(f"Loading embeddings from {file_path}")
    with open(file_path, 'r', encoding='utf-8') as file:
        embeddings = json.load(file)
        logging.info(f"Loaded {len(embeddings)} embeddings")
        return embeddings

# Function to get relevant articles based on user query
def get_relevant_documents(query, embeddings_data, model, top_k=3):
    logging.info(f"Received query: {query}")
    start_time = time.time()

    query_embedding = model.encode(query)
    similarities = []

    for i, entry in enumerate(embeddings_data):
        embedding = np.array(entry['embedding'])
        similarity = cosine_similarity([query_embedding], [embedding])[0][0]
        similarities.append((entry, similarity))
        if i % 100 == 0:  # Log every 100 iterations
            logging.debug(f"Processed {i} embeddings")

    logging.info("Sorting similarities")
    similarities.sort(key=lambda x: x[1], reverse=True)
    top_entries = [entry for entry, _ in similarities[:top_k]]
    
    end_time = time.time()
    duration = end_time - start_time
    
    logging.info(f"Query processed in {duration:.2f} seconds")
    logging.info(f"Top {top_k} documents returned with similarities: {[sim[1] for sim in similarities[:top_k]]}")

    return top_entries

# Function to format relevant documents into a string
def format_documents(documents):
    logging.info(f"Formatting {len(documents)} documents")
    formatted = ""
    for doc in documents:
        formatted += f"Relevant article: {doc['name']}\n{doc['content']}\n\n"
    return formatted

# Main chatbot function that integrates RAG
def respond(
    message,
    history: list[tuple[str, str]],
    system_message,
    max_tokens,
    temperature,
    top_p,
    embeddings_data,
    tokenizer,
    model
):
    logging.info(f"New user query: {message}")
    
    start_time = time.time()

    # Search for relevant documents based on user input
    relevant_docs = get_relevant_documents(message, embeddings_data, tokenizer, model)
    retrieved_context = format_documents(relevant_docs)
    
    # Log the statistics about the retrieved documents
    logging.info(f"Total documents retrieved: {len(relevant_docs)}")
    logging.info(f"Documents: " + str([doc['name'] for doc in relevant_docs]))
    
    # Add the retrieved context as part of the system message
    system_message_with_context = system_message + "\n\n" + "Relevant documents:\n" + retrieved_context
    logging.info("System message updated with retrieved context")

    messages = [{"role": "system", "content": system_message_with_context}]

    for val in history:
        if val[0]:
            messages.append({"role": "user", "content": val[0]})
        if val[1]:
            messages.append({"role": "assistant", "content": val[1]})

    messages.append({"role": "user", "content": message})
    logging.info("Messages prepared for InferenceClient")

    client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
    
    logging.info("Sending request to InferenceClient")
    response = ""
    
    # Collect the full response instead of yielding each token
    for message in client.chat_completion(
        messages,
        max_tokens=max_tokens,
        stream=True,
        temperature=temperature,
        top_p=top_p,
    ):
        token = message.choices[0].delta.content
        response += token
    
    end_time = time.time()
    total_duration = end_time - start_time
    logging.info(f"Response generated in {total_duration:.2f} seconds")
    
    return response  # Return the complete response as a string


# Load embeddings and model once at startup
embeddings_file = 'Code Civil vectorised.json'
logging.info("Starting application, loading embeddings and model")
embeddings_data = load_embeddings(embeddings_file)
embedding_model = SentenceTransformer('Lajavaness/bilingual-embedding-small', trust_remote_code=True)
logging.info("Model and embeddings loaded successfully")

# Gradio interface
demo = gr.ChatInterface(
    lambda message, history, system_message, max_tokens, temperature, top_p: respond(
        message, history, system_message, max_tokens, temperature, top_p, embeddings_data, embedding_model
    ),
    additional_inputs=[
        gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.95,
            step=0.05,
            label="Top-p (nucleus sampling)",
        ),
    ],
)

if __name__ == "__main__":
    logging.info("Launching Gradio app")
    demo.launch()