Spaces:
Sleeping
Sleeping
# Install necessary packages
#!pip install streamlit
#!pip install wikipedia
#!pip install langchain
#!pip install langchain_community
#!pip install pypdf
#!pip install sentence-transformers
#!pip install chromadb
#!pip install huggingface_hub
#!pip install transformers
import streamlit as st | |
from langchain_community.document_loaders import PyPDFLoader | |
from langchain.text_splitter import RecursiveCharacterTextSplitter, SentenceTransformersTokenTextSplitter | |
import chromadb | |
from chromadb.utils.embedding_functions import SentenceTransformerEmbeddingFunction | |
from huggingface_hub import login, InferenceClient | |
from sentence_transformers import CrossEncoder | |
import numpy as np | |
import random | |
import string | |
import os | |
# Chat model used for both query expansion and answering.
model_name = 'mistralai/Mistral-7B-Instruct-v0.3'

# Sidebar widgets: the PDF to question first, then the Hugging Face token.
uploaded_file = st.sidebar.file_uploader("Upload your PDF", type="pdf")
HF_TOKEN = st.sidebar.text_input(
    "Enter your Hugging Face token:",
    value="",
    type="password",
)

# Make sure the token-validation error slot exists in session state
# before anything reads it.
if 'error_message' not in st.session_state:
    st.session_state['error_message'] = ""
# Function to validate token | |
def validate_token(token):
    """Check whether a Hugging Face access token is accepted by the Hub.

    Logs in with ``token`` and then asks the Hub which account the token
    belongs to. Any failure along the way is reported as "invalid".

    Args:
        token: The access token string entered by the user.

    Returns:
        True if both the login and the ``whoami`` lookup succeed,
        False otherwise.
    """
    # Imported here because the file-level import only provides `login`
    # and `InferenceClient`; without it the `HfApi` name is undefined and
    # the broad handler below would report every token as invalid.
    from huggingface_hub import HfApi

    try:
        # Attempt to log in with the provided token
        login(token=token)
        # Pass the token explicitly so the lookup checks this token rather
        # than whatever credential happens to be cached on the machine.
        HfApi().whoami(token=token)
    except Exception:
        # Validation boundary: a rejected token and a failed request are
        # both surfaced to the caller simply as "not valid".
        return False
    return True
# Check the entered token and report the outcome in the sidebar.
if HF_TOKEN:
    token_ok = validate_token(HF_TOKEN)
    # A valid token clears any stored error; an invalid one records it.
    st.session_state.error_message = (
        "" if token_ok else "Invalid token. Please try again."
    )
    if token_ok:
        st.sidebar.success("Token is valid!")
    else:
        st.sidebar.error(st.session_state.error_message)
elif st.session_state.error_message:
    # No token in the box, but an earlier error is still stored: show it.
    st.sidebar.error(st.session_state.error_message)
if uploaded_file:
    # PyPDFLoader reads from a path, so the upload is written to a
    # temporary file first. Only the final path component of the supplied
    # name is used so the file always lands directly inside /tmp.
    temp_file_path = os.path.join("/tmp", os.path.basename(uploaded_file.name))
    try:
        with open(temp_file_path, "wb") as f:
            f.write(uploaded_file.getbuffer())
        # Load the PDF using PyPDFLoader
        docs = PyPDFLoader(temp_file_path).load()
    finally:
        # Remove the temporary copy even when writing or parsing fails;
        # the existence check keeps a failed open() from being masked by
        # a FileNotFoundError raised here.
        if os.path.exists(temp_file_path):
            os.remove(temp_file_path)
else:
    st.warning("Please upload a PDF file.")
# The chat transcript lives in session state; create the empty list the
# first time it is needed.
if "history" not in st.session_state:
    st.session_state["history"] = []
# Function to generate a random string for collection name | |
def generate_random_string(max_length=60):
    """Return a random alphanumeric string for use as a collection name.

    Args:
        max_length: Inclusive upper bound on the length of the result.
            Must be between 1 and 60.

    Returns:
        A string of ASCII letters and digits. Its length is drawn
        uniformly from ``min(3, max_length)`` up to ``max_length``.

    Raises:
        ValueError: If ``max_length`` is greater than 60 or smaller than 1.
    """
    if max_length > 60:
        raise ValueError("The maximum length cannot exceed 60 characters.")
    if max_length < 1:
        raise ValueError("The maximum length must be at least 1 character.")
    # Chroma rejects collection names shorter than three characters, and
    # one- or two-character names collide easily, so never go below three
    # unless the caller explicitly asked for a shorter maximum.
    min_length = min(3, max_length)
    length = random.randint(min_length, max_length)
    characters = string.ascii_letters + string.digits
    return ''.join(random.choices(characters, k=length))


# Name of the Chroma collection created by this script run.
collection_name = generate_random_string()
# Function for query expansion | |
def augment_multiple_query(query):
    """Ask the chat model for related questions to broaden retrieval.

    Args:
        query: The user's original question.

    Returns:
        A list with the non-blank lines of the model's reply, each stripped
        of surrounding whitespace. The prompt asks for short related
        questions, but the exact line format is up to the model.
    """
    client = InferenceClient(model_name, token=HF_TOKEN)
    completion = client.chat_completion(
        messages=[
            {
                "role": "system",
                "content": """You are a helpful assistant.
Suggest up to five additional related questions to help them find the information they need for the provided question.
Suggest only short questions without compound sentences. Suggest a variety of questions that cover different aspects of the topic.
Make sure they are complete questions, and that they are related to the original question."""
            },
            {
                "role": "user",
                "content": query
            }
        ],
        max_tokens=500,
    )
    # Guard against a missing reply, and drop blank lines so they are not
    # sent to the vector store as queries.
    reply = completion.choices[0].message.content or ""
    return [line.strip() for line in reply.splitlines() if line.strip()]
# Function to handle RAG-based question answering | |
def rag_advanced(user_query):
    """Answer a question about the uploaded PDF using retrieval + an LLM.

    Steps: split the text of the module-level ``docs`` into chunks, index
    them in a new Chroma collection named ``collection_name``, retrieve
    passages for the original question plus the expanded questions,
    re-rank the unique passages with a cross-encoder, and ask the chat
    model to answer from the five best passages.

    Args:
        user_query: The question typed by the user.

    Returns:
        The model's answer as one string, assembled from the streamed
        response chunks.
    """
    # Text Splitting: first by characters, then by tokens.
    character_splitter = RecursiveCharacterTextSplitter(
        separators=["\n\n", "\n", ". ", " ", ""],
        chunk_size=1000,
        chunk_overlap=0,
    )
    concat_texts = "".join(doc.page_content for doc in docs)
    character_split_texts = character_splitter.split_text(concat_texts)
    token_splitter = SentenceTransformersTokenTextSplitter(chunk_overlap=0, tokens_per_chunk=256)
    token_split_texts = [
        token_chunk
        for char_chunk in character_split_texts
        for token_chunk in token_splitter.split_text(char_chunk)
    ]

    # Embedding and Document Storage
    embedding_function = SentenceTransformerEmbeddingFunction()
    chroma_client = chromadb.Client()
    chroma_collection = chroma_client.create_collection(collection_name, embedding_function=embedding_function)
    ids = [str(i) for i in range(len(token_split_texts))]
    chroma_collection.add(ids=ids, documents=token_split_texts)

    # Document Retrieval: blank expansion lines carry no information, so
    # they are not sent as queries.
    augmented_queries = [q for q in augment_multiple_query(user_query) if q.strip()]
    joint_query = [user_query] + augmented_queries
    # Only the passage texts are used below, so embeddings are not requested.
    results = chroma_collection.query(query_texts=joint_query, n_results=5, include=['documents'])
    retrieved_documents = results['documents']
    # De-duplicate while keeping first-seen order so runs are reproducible.
    unique_documents = list(
        dict.fromkeys(doc for hits in retrieved_documents for doc in hits)
    )

    # Re-Ranking: score each passage against the original question and
    # keep the five highest-scoring ones.
    cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
    pairs = [[user_query, doc] for doc in unique_documents]
    scores = cross_encoder.predict(pairs)
    top_indices = np.argsort(scores)[::-1][:5]
    top_documents = [unique_documents[idx] for idx in top_indices]

    # LLM Reference
    client = InferenceClient(model_name, token=HF_TOKEN)
    response_parts = []
    for message in client.chat_completion(
        messages=[
            {
                "role": "system",
                "content": """You are a helpful assitant.
You will be shown the user's questions, and the relevant information from the related documents.
Answer the user's question using only this information."""
            },
            {
                "role": "user",
                "content": f"Questions: {user_query}. \n Information: {top_documents}"
            }
        ],
        max_tokens=500,
        stream=True,
    ):
        # Streamed chunks may carry no text (content is optional), so
        # skip those instead of concatenating None.
        piece = message.choices[0].delta.content
        if piece:
            response_parts.append(piece)
    return "".join(response_parts)
# Page header: title, a one-line instruction, and a link to the page
# where a Hugging Face token can be created.
st.title("PDF RAG Chatbot")
st.markdown("Upload your PDF and enter your 🤗 token!")
st.link_button(
    label="Get Token Here",
    url="https://huggingface.co/settings/tokens",
)
# The question box is only offered once a PDF has been uploaded.
if uploaded_file:
    user_input = st.text_input("You: ", "", placeholder="Type your question here...")
    if user_input:
        # Answer the question and record the exchange in the transcript.
        response = rag_advanced(user_input)
        exchange = {"user": user_input, "bot": response}
        st.session_state.history.append(exchange)

# Render every recorded exchange, oldest first.
# NOTE(review): the nesting of this loop and of the divider below is
# reconstructed from a copy of the file that lost its indentation; confirm
# against the original layout.
for turn in st.session_state.history:
    st.write(f"You: {turn['user']}")
    st.write(f"Bot: {turn['bot']}")
st.markdown("-----------------")
# Closing "about" section describing the app.
st.markdown("What is this app?")
_about_text = """This is a simple RAG application using PDF import.
The model for chat is Mistral-7B-Instruct-v0.3.
Main libraries: Langchain (text splitting), Chromadb (vector store)
This RAG uses query expansion and re-ranking to improve the quality.
Feel free to check the files or DM me for any questions. Thank you."""
st.markdown(_about_text)