import gradio as gr
from huggingface_hub import InferenceClient
import json
import numpy as np
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
import logging
import time
# Set up logging
logging.basicConfig(level=logging.INFO)
"""
For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
"""
# Load embeddings from a JSON file
def load_embeddings(file_path):
    logging.info(f"Loading embeddings from {file_path}")
    with open(file_path, 'r', encoding='utf-8') as file:
        embeddings = json.load(file)
    logging.info(f"Loaded {len(embeddings)} embeddings")
    return embeddings

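# Each entry in the JSON file is expected to look roughly like the example below.
# The field names ('name', 'content', 'embedding') are inferred from how entries
# are used later in this file; the article text and vector are illustrative only:
#   {
#     "name": "Article 544",
#     "content": "La propriété est le droit de jouir et disposer des choses ...",
#     "embedding": [0.012, -0.034, ...]
#   }
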
# Function to get relevant articles based on user query
def get_relevant_documents(query, embeddings_data, model, top_k=3):
    logging.info(f"Received query: {query}")
    start_time = time.time()
    query_embedding = model.encode(query)
    similarities = []
    for i, entry in enumerate(embeddings_data):
        embedding = np.array(entry['embedding'])
        similarity = cosine_similarity([query_embedding], [embedding])[0][0]
        similarities.append((entry, similarity))
        if i % 100 == 0:  # Log every 100 iterations
            logging.debug(f"Processed {i} embeddings")
    logging.info("Sorting similarities")
    similarities.sort(key=lambda x: x[1], reverse=True)
    top_entries = [entry for entry, _ in similarities[:top_k]]
    end_time = time.time()
    duration = end_time - start_time
    logging.info(f"Query processed in {duration:.2f} seconds")
    logging.info(f"Top {top_k} documents returned with similarities: {[sim[1] for sim in similarities[:top_k]]}")
    return top_entries

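# Optional alternative (a sketch, not wired into the app): score all documents in one
# matrix operation instead of looping per entry. This assumes every stored embedding
# has the same dimensionality as the query embedding.
def get_relevant_documents_vectorized(query, embeddings_data, model, top_k=3):
    query_embedding = model.encode(query)
    # Stack all document vectors into a single (num_docs, dim) matrix.
    doc_matrix = np.array([entry['embedding'] for entry in embeddings_data])
    # One call scores the query against every document at once.
    scores = cosine_similarity([query_embedding], doc_matrix)[0]
    # Indices of the top_k highest scores, best first.
    top_indices = np.argsort(scores)[::-1][:top_k]
    return [embeddings_data[i] for i in top_indices]
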
# Function to format relevant documents into a string
def format_documents(documents):
    logging.info(f"Formatting {len(documents)} documents")
    formatted = ""
    for doc in documents:
        formatted += f"Relevant article: {doc['name']}\n{doc['content']}\n\n"
    return formatted

# Main chatbot function that integrates RAG
def respond(
    message,
    history: list[tuple[str, str]],
    system_message,
    max_tokens,
    temperature,
    top_p,
    embeddings_data,
    model,
):
    logging.info(f"New user query: {message}")
    start_time = time.time()

    # Search for relevant documents based on user input
    relevant_docs = get_relevant_documents(message, embeddings_data, model)
    retrieved_context = format_documents(relevant_docs)

    # Log statistics about the retrieved documents
    logging.info(f"Total documents retrieved: {len(relevant_docs)}")
    logging.info("Documents: " + str([doc['name'] for doc in relevant_docs]))

    # Add the retrieved context as part of the system message
    system_message_with_context = system_message + "\n\n" + "Relevant documents:\n" + retrieved_context
    logging.info("System message updated with retrieved context")

    messages = [{"role": "system", "content": system_message_with_context}]
    for val in history:
        if val[0]:
            messages.append({"role": "user", "content": val[0]})
        if val[1]:
            messages.append({"role": "assistant", "content": val[1]})
    messages.append({"role": "user", "content": message})
    logging.info("Messages prepared for InferenceClient")

    client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
    logging.info("Sending request to InferenceClient")

    response = ""
    # Collect the full response instead of yielding each token
    for chunk in client.chat_completion(
        messages,
        max_tokens=max_tokens,
        stream=True,
        temperature=temperature,
        top_p=top_p,
    ):
        token = chunk.choices[0].delta.content
        if token:  # The final streamed chunk may carry an empty delta
            response += token

    end_time = time.time()
    total_duration = end_time - start_time
    logging.info(f"Response generated in {total_duration:.2f} seconds")
    return response  # Return the complete response as a string

# Load embeddings and model once at startup
embeddings_file = 'Code Civil vectorised.json'
logging.info("Starting application, loading embeddings and model")
embeddings_data = load_embeddings(embeddings_file)
embedding_model = SentenceTransformer('Lajavaness/bilingual-embedding-small', trust_remote_code=True)
logging.info("Model and embeddings loaded successfully")
# Gradio interface
demo = gr.ChatInterface(
    lambda message, history, system_message, max_tokens, temperature, top_p: respond(
        message, history, system_message, max_tokens, temperature, top_p, embeddings_data, embedding_model
    ),
    additional_inputs=[
        gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.95,
            step=0.05,
            label="Top-p (nucleus sampling)",
        ),
    ],
)

if __name__ == "__main__":
    logging.info("Launching Gradio app")
    demo.launch()