File size: 7,065 Bytes
9b0079f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
529a843
 
9b0079f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
529a843
 
 
 
9b0079f
 
 
 
529a843
9b0079f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
# Install necessary packages
#!pip install streamlit
#!pip install wikipedia
#!pip install langchain_community
#!pip install sentence-transformers
#!pip install chromadb
#!pip install huggingface_hub
#!pip install transformers

import streamlit as st
from langchain_community.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter, SentenceTransformersTokenTextSplitter
import chromadb
from chromadb.utils.embedding_functions import SentenceTransformerEmbeddingFunction
from huggingface_hub import login, InferenceClient
from sentence_transformers import CrossEncoder
import numpy as np
import random
import string
import tempfile


# Sidebar inputs and fixed configuration
uploaded_file = st.sidebar.file_uploader(label="Upload your PDF", type="pdf")
# Chat model used both for query expansion and for the final answer
model_name = "mistralai/Mistral-7B-Instruct-v0.3"
HF_TOKEN = st.sidebar.text_input(
    label="Enter your Hugging Face token:",
    value="",
    type="password",
)


# Seed the stored token-error message the first time the key is missing
if "error_message" not in st.session_state:
    st.session_state["error_message"] = ""

# Function to validate token
def validate_token(token):
    """Check whether a Hugging Face access token is accepted by the Hub.

    Logs in with the token and then asks the Hub which account it belongs to.

    Args:
        token: Token string typed by the user.

    Returns:
        True when both the login and the identity lookup succeed; False when
        the token is empty or either call raises.
    """
    if not token:
        return False

    # Imported here because the module-level import only brings in `login`
    # and `InferenceClient`. Without it, `HfApi` is undefined and the lookup
    # below would fail with NameError for every token.
    from huggingface_hub import HfApi

    try:
        login(token=token)
        HfApi().whoami(token=token)
    except Exception:
        # Any failure here (rejected token, network error, ...) is reported
        # to the caller as "not valid" rather than crashing the page.
        return False
    return True

# Report the outcome of token validation in the sidebar
if not HF_TOKEN:
    # Nothing typed: re-display an error message stored earlier, if any
    if st.session_state.error_message:
        st.sidebar.error(st.session_state.error_message)
elif validate_token(HF_TOKEN):
    # A valid token clears any previously stored error
    st.session_state.error_message = ""
    st.sidebar.success("Token is valid!")
else:
    st.session_state.error_message = "Invalid token. Please try again."
    st.sidebar.error(st.session_state.error_message)

if uploaded_file:
    # PyPDFLoader is given a filesystem path, so the upload is written to disk
    # first. The temporary directory (and the PDF inside it) is removed as soon
    # as the `with` block exits, so reruns of the script do not pile up files.
    with tempfile.TemporaryDirectory() as temp_dir:
        temp_file_path = f"{temp_dir}/uploaded.pdf"
        with open(temp_file_path, "wb") as temp_file:
            temp_file.write(uploaded_file.getbuffer())

        # Load the PDF while the file still exists; `docs` is what the rest of
        # the script reads afterwards.
        docs = PyPDFLoader(temp_file_path).load()
else:
    st.warning("Please upload a PDF file.")


# Chat history lives in session state; start it as an empty list when absent
if "history" not in st.session_state:
    st.session_state["history"] = []

# Function to generate a random string for collection name
def generate_random_string(max_length=60, min_length=3):
    """Return a random alphanumeric string, used as the collection name.

    Args:
        max_length: Upper bound on the length of the result; at most 60.
        min_length: Lower bound on the length of the result. Defaults to 3
            because Chroma rejects collection names shorter than three
            characters. It is capped at ``max_length`` so small limits
            still work.

    Returns:
        A string of ASCII letters and digits whose length is drawn uniformly
        from the allowed range.

    Raises:
        ValueError: If ``max_length`` is greater than 60 or less than 1.
    """
    if max_length > 60:
        raise ValueError("The maximum length cannot exceed 60 characters.")
    if max_length < 1:
        raise ValueError("The maximum length must be at least 1 character.")

    shortest = max(1, min(min_length, max_length))
    length = random.randint(shortest, max_length)
    alphabet = string.ascii_letters + string.digits
    return "".join(random.choices(alphabet, k=length))


collection_name = generate_random_string()

# Function for query expansion
def augment_multiple_query(query):
    """Ask the chat model for related questions to widen retrieval.

    Uses the module-level ``model_name`` and ``HF_TOKEN`` to build the client.

    Args:
        query: The user's original question.

    Returns:
        A list with one entry per non-blank line of the model's reply, each
        stripped of surrounding whitespace. Any numbering or bullets the model
        added are kept as-is.
    """
    client = InferenceClient(model_name, token=HF_TOKEN)
    content = client.chat_completion(
        messages=[
            {
                "role": "system",
                "content": """You are a helpful assistant.
                Suggest up to five additional related questions to help them find the information they need for the provided question.
                Suggest only short questions without compound sentences. Suggest a variety of questions that cover different aspects of the topic.
                Make sure they are complete questions, and that they are related to the original question."""
            },
            {
                "role": "user",
                "content": query
            }
        ],
        max_tokens=500,
    )
    # Treat a missing reply body as "no suggestions" instead of crashing
    reply = content.choices[0].message.content or ""
    # Drop blank lines so no empty string is later used as a retrieval query
    return [line.strip() for line in reply.splitlines() if line.strip()]

# Function to handle RAG-based question answering
def rag_advanced(user_query):
    """Answer a question about the uploaded PDF with retrieval-augmented generation.

    Pipeline: split the PDF text into chunks, index them in a new Chroma
    collection, retrieve chunks for the original question plus the expanded
    questions, re-rank the candidates with a cross-encoder, and send the five
    best chunks to the chat model.

    Reads the module-level ``docs``, ``collection_name``, ``model_name`` and
    ``HF_TOKEN``.

    Args:
        user_query: The question typed by the user.

    Returns:
        The model's answer as one string, or a short notice when the PDF
        yielded no text to search.
    """
    # Text Splitting
    character_splitter = RecursiveCharacterTextSplitter(separators=["\n\n", "\n", ". ", " ", ""], chunk_size=1000, chunk_overlap=0)
    concat_texts = "".join(doc.page_content for doc in docs)
    character_split_texts = character_splitter.split_text(concat_texts)

    token_splitter = SentenceTransformersTokenTextSplitter(chunk_overlap=0, tokens_per_chunk=256)
    token_split_texts = [
        token_chunk
        for char_chunk in character_split_texts
        for token_chunk in token_splitter.split_text(char_chunk)
    ]

    # A PDF without extractable text leaves nothing to index or retrieve
    if not token_split_texts:
        return "No text could be extracted from the uploaded PDF."

    # Embedding and Document Storage
    embedding_function = SentenceTransformerEmbeddingFunction()
    chroma_client = chromadb.Client()
    chroma_collection = chroma_client.create_collection(collection_name, embedding_function=embedding_function)

    ids = [str(i) for i in range(len(token_split_texts))]
    chroma_collection.add(ids=ids, documents=token_split_texts)

    # Document Retrieval
    augmented_queries = [q for q in augment_multiple_query(user_query) if q.strip()]
    joint_query = [user_query] + augmented_queries

    # Only the document texts are used below, so embeddings are not requested
    results = chroma_collection.query(query_texts=joint_query, n_results=5, include=['documents'])
    retrieved_documents = results['documents']

    # De-duplicate while keeping first-seen order so ranking ties are stable
    unique_documents = list(dict.fromkeys(
        chunk for hits in retrieved_documents for chunk in hits
    ))

    # Re-Ranking
    cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
    pairs = [[user_query, chunk] for chunk in unique_documents]
    scores = cross_encoder.predict(pairs)

    # Highest score first, keep the best five
    top_indices = np.argsort(scores)[::-1][:5]
    top_documents = [unique_documents[idx] for idx in top_indices]

    # LLM Reference
    client = InferenceClient(model_name, token=HF_TOKEN)
    pieces = []
    for message in client.chat_completion(
        messages=[
            {
                "role": "system",
                "content": """You are a helpful assitant.
                You will be shown the user's questions, and the relevant information from the related documents.
                Answer the user's question using only this information."""
            },
            {
                "role": "user",
                "content": f"Questions: {user_query}. \n Information: {top_documents}"
            }
        ],
        max_tokens=500,
        stream=True,
    ):
        if not message.choices:
            continue
        # A streamed chunk may carry no text (content is optional); skip it
        # instead of concatenating None onto the answer.
        delta_text = message.choices[0].delta.content
        if delta_text:
            pieces.append(delta_text)

    return "".join(pieces)

# Streamlit UI
st.title("PDF RAG Chatbot")
st.markdown("Upload your PDF and enter your 🤗 token!")
st.link_button(label="Get Token Here", url="https://huggingface.co/settings/tokens")

# The question box is only rendered once a PDF has been uploaded
if uploaded_file:
    user_input = st.text_input(
        label="You: ",
        value="",
        placeholder="Type your question here...",
    )
    if user_input:
        response = rag_advanced(user_input)
        turn = {"user": user_input, "bot": response}
        st.session_state.history.append(turn)

# Replay every stored question/answer pair, oldest first
for entry in st.session_state.history:
    question, reply = entry["user"], entry["bot"]
    st.write(f"You: {question}")
    st.write(f"Bot: {reply}")


# Static "about" section at the bottom of the page
st.markdown("-----------------")
st.markdown("What is this app?")
about_text = """This is a simple RAG application using PDF import.  
The model for chat is Mistral-7B-Instruct-v0.3.  
Main libraries: Langchain (text splitting), Chromadb (vector store)  
This RAG uses query expansion and re-ranking to improve the quality.  
Feel free to check the files or DM me for any questions. Thank you."""
st.markdown(about_text)