Spaces:
Sleeping
Sleeping
File size: 7,041 Bytes
9b0079f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 |
# Install necessary packages
#!pip install streamlit
#!pip install wikipedia
#!pip install langchain
#!pip install langchain_community
#!pip install pypdf
#!pip install sentence-transformers
#!pip install chromadb
#!pip install huggingface_hub
#!pip install transformers
import os
import random
import string
import tempfile

import chromadb
import numpy as np
import streamlit as st
from chromadb.utils.embedding_functions import SentenceTransformerEmbeddingFunction
from huggingface_hub import HfApi, InferenceClient, login
from langchain.text_splitter import RecursiveCharacterTextSplitter, SentenceTransformersTokenTextSplitter
from langchain_community.document_loaders import PyPDFLoader
from sentence_transformers import CrossEncoder
# Sidebar inputs: the PDF to query and the Hugging Face token used for
# the chat-model calls. Widget order here is the order shown in the sidebar.
uploaded_file = st.sidebar.file_uploader(label="Upload your PDF", type="pdf")
model_name = "mistralai/Mistral-7B-Instruct-v0.3"
HF_TOKEN = st.sidebar.text_input(
    label="Enter your Hugging Face token:",
    value="",
    type="password",
)

# Make sure the error-message slot exists in session state before it is read.
if "error_message" not in st.session_state:
    st.session_state["error_message"] = ""
# Function to validate token
def validate_token(token):
    """Check whether a Hugging Face access token is accepted by the Hub.

    Args:
        token: The access token string entered by the user.

    Returns:
        True when both the login call and the ``whoami`` lookup succeed,
        False when either of them raises.
    """
    try:
        # Attempt to log in with the provided token
        login(token=token)
        # Ask the Hub which account owns the token. The token is passed
        # explicitly so the check covers the value that was typed in rather
        # than any credential that may already be cached.
        HfApi().whoami(token=token)
    except Exception:
        # Best-effort check: any failure (rejected token, network error, ...)
        # is reported as "not valid" instead of crashing the app.
        return False
    return True
# Show the token status in the sidebar.
if not HF_TOKEN:
    # Nothing entered: only re-display an error that is already stored.
    if st.session_state.error_message:
        st.sidebar.error(st.session_state.error_message)
elif validate_token(HF_TOKEN):
    # A valid token clears the stored error.
    st.session_state.error_message = ""
    st.sidebar.success("Token is valid!")
else:
    # Store the error in session state, then display it.
    st.session_state.error_message = "Invalid token. Please try again."
    st.sidebar.error(st.session_state.error_message)
if uploaded_file:
    # PyPDFLoader is given a filesystem path, so the upload is written to disk
    # first. A uniquely named temporary file is used instead of the
    # user-supplied filename so that uploads sharing a name cannot overwrite
    # each other and the name cannot influence where the file is written.
    with tempfile.NamedTemporaryFile(suffix=".pdf", delete=False) as temp_file:
        temp_file.write(uploaded_file.getbuffer())
        temp_file_path = temp_file.name
    try:
        # Load the PDF using PyPDFLoader
        docs = PyPDFLoader(temp_file_path).load()
    finally:
        # Remove the temporary file even when loading fails.
        os.remove(temp_file_path)
else:
    st.warning("Please upload a PDF file.")
# Chat history lives in session state; create the empty list on first use.
if "history" not in st.session_state:
    st.session_state["history"] = []
# Function to generate a random string for collection name
def generate_random_string(max_length=60, min_length=3):
    """Return a random alphanumeric string of random length.

    The result is used as a Chroma collection name. Chroma rejects collection
    names shorter than 3 characters, so the default lower bound is 3 rather
    than 1.

    Args:
        max_length: Largest length that may be produced (1 to 60).
        min_length: Smallest length that may be produced. It is clamped to
            ``max_length`` so that small ``max_length`` values keep working.

    Returns:
        A string of ASCII letters and digits whose length lies between the
        effective lower bound and ``max_length``, inclusive.

    Raises:
        ValueError: If ``max_length`` is greater than 60 or less than 1.
    """
    if max_length > 60:
        raise ValueError("The maximum length cannot exceed 60 characters.")
    if max_length < 1:
        raise ValueError("The maximum length must be at least 1 character.")
    # Never ask for more characters than max_length allows, never fewer than 1.
    lower_bound = max(1, min(min_length, max_length))
    length = random.randint(lower_bound, max_length)
    characters = string.ascii_letters + string.digits
    return "".join(random.choices(characters, k=length))


# Name of the vector-store collection used by this script run.
collection_name = generate_random_string()
# Function for query expansion
def augment_multiple_query(query):
    """Ask the chat model for related questions to broaden retrieval.

    Args:
        query: The user's original question.

    Returns:
        The non-blank lines of the model's reply, whitespace-stripped and in
        the order the model produced them. The list is empty when the reply
        contains no text.
    """
    client = InferenceClient(model_name, token=HF_TOKEN)
    completion = client.chat_completion(
        messages=[
            {
                "role": "system",
                "content": """You are a helpful assistant.
Suggest up to five additional related questions to help them find the information they need for the provided question.
Suggest only short questions without compound sentences. Suggest a variety of questions that cover different aspects of the topic.
Make sure they are complete questions, and that they are related to the original question."""
            },
            {
                "role": "user",
                "content": query
            }
        ],
        max_tokens=500,
    )
    # The reply text may be missing; treat that as an empty reply.
    reply = completion.choices[0].message.content or ""
    # Drop blank lines so they are not later used as search queries.
    return [line.strip() for line in reply.splitlines() if line.strip()]
# Function to handle RAG-based question answering
def rag_advanced(user_query):
    """Answer a question about the loaded PDF via retrieve, re-rank, generate.

    The module-level ``docs`` are split into chunks and indexed in a Chroma
    collection named ``collection_name``. The question, together with the
    related questions from ``augment_multiple_query``, is used to retrieve
    candidate chunks; a cross-encoder scores them against the original
    question, and the five best are passed to the chat model.

    Args:
        user_query: The question typed by the user.

    Returns:
        The model's answer as a single string, or a short notice when the PDF
        yielded no text to index.
    """
    # Text Splitting
    character_splitter = RecursiveCharacterTextSplitter(
        separators=["\n\n", "\n", ". ", " ", ""],
        chunk_size=1000,
        chunk_overlap=0,
    )
    # Join pages with a newline so the last word of one page is not fused
    # with the first word of the next.
    concat_texts = "\n".join(doc.page_content for doc in docs)
    character_split_texts = character_splitter.split_text(concat_texts)
    token_splitter = SentenceTransformersTokenTextSplitter(chunk_overlap=0, tokens_per_chunk=256)
    token_split_texts = [
        chunk
        for text in character_split_texts
        for chunk in token_splitter.split_text(text)
    ]
    if not token_split_texts:
        # Nothing to index (e.g. a PDF without extractable text).
        return "No text could be extracted from the uploaded PDF."

    # Embedding and Document Storage
    embedding_function = SentenceTransformerEmbeddingFunction()
    chroma_client = chromadb.Client()
    chroma_collection = chroma_client.create_collection(
        collection_name, embedding_function=embedding_function
    )
    try:
        ids = [str(i) for i in range(len(token_split_texts))]
        chroma_collection.add(ids=ids, documents=token_split_texts)

        # Document Retrieval
        augmented_queries = augment_multiple_query(user_query)
        # Skip blank lines from the expansion step; they are not real queries.
        joint_query = [user_query] + [q for q in augmented_queries if q.strip()]
        results = chroma_collection.query(
            query_texts=joint_query,
            # Never request more hits than there are indexed chunks.
            n_results=min(5, len(token_split_texts)),
            include=["documents"],
        )
    finally:
        # The collection is only needed for this one question; drop it so
        # collections do not pile up in the client across questions.
        chroma_client.delete_collection(name=collection_name)

    # De-duplicate across queries while keeping first-seen order.
    unique_documents = list(
        dict.fromkeys(doc for hits in results["documents"] for doc in hits)
    )

    # Re-Ranking
    cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
    pairs = [[user_query, doc] for doc in unique_documents]
    scores = cross_encoder.predict(pairs)
    # Highest score first, keep at most five.
    top_indices = np.argsort(scores)[::-1][:5]
    top_documents = [unique_documents[idx] for idx in top_indices]

    # LLM Reference
    client = InferenceClient(model_name, token=HF_TOKEN)
    answer_parts = []
    for message in client.chat_completion(
        messages=[
            {
                "role": "system",
                "content": """You are a helpful assitant.
You will be shown the user's questions, and the relevant information from the related documents.
Answer the user's question using only this information."""
            },
            {
                "role": "user",
                "content": f"Questions: {user_query}. \n Information: {top_documents}"
            }
        ],
        max_tokens=500,
        stream=True,
    ):
        if not message.choices:
            continue
        # A streamed chunk may carry no text; skip it rather than
        # concatenating None onto the answer.
        piece = message.choices[0].delta.content
        if piece:
            answer_parts.append(piece)
    return "".join(answer_parts)
# Page header.
st.title("PDF RAG Chatbot")
st.markdown("Upload your PDF and enter your 🤗 token!")
st.link_button("Get Token Here", "https://huggingface.co/settings/tokens")

# The question box is only offered once a PDF has been uploaded.
if uploaded_file:
    user_input = st.text_input("You: ", "", placeholder="Type your question here...")
    if user_input:
        # Answer the question and remember the exchange for this session.
        response = rag_advanced(user_input)
        exchange = {"user": user_input, "bot": response}
        st.session_state.history.append(exchange)

# Replay every stored exchange, oldest first.
for turn in st.session_state.history:
    question, answer = turn["user"], turn["bot"]
    st.write(f"You: {question}")
    st.write(f"Bot: {answer}")

# Footer describing the app.
st.markdown("-----------------")
st.markdown("What is this app?")
about_text = """This is a simple RAG application using PDF import.
The model for chat is Mistral-7B-Instruct-v0.3.
Main libraries: Langchain (text splitting), Chromadb (vector store)
This RAG uses query expansion and re-ranking to improve the quality.
Feel free to check the files or DM me for any questions. Thank you."""
st.markdown(about_text)