rk68 commited on
Commit
1fff800
·
verified ·
1 Parent(s): ff8b6e7

Upload 2 files

Browse files
Files changed (2) hide show
  1. main_hf.py +329 -0
  2. policy.pdf +0 -0
main_hf.py ADDED
@@ -0,0 +1,329 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import pandas as pd
3
+ import streamlit as st
4
+ from pinecone import Pinecone
5
+ from llama_index.llms.gemini import Gemini
6
+ from llama_index.vector_stores.pinecone import PineconeVectorStore
7
+ from llama_index.core import (
8
+ StorageContext, VectorStoreIndex, SimpleDirectoryReader,
9
+ get_response_synthesizer, Settings
10
+ )
11
+ from llama_index.core.node_parser import SentenceSplitter
12
+ from llama_index.core.retrievers import (
13
+ VectorIndexRetriever, RouterRetriever
14
+ )
15
+ from llama_index.retrievers.bm25 import BM25Retriever
16
+ from llama_index.core.tools import RetrieverTool
17
+ from llama_index.core.query_engine import (
18
+ RetrieverQueryEngine, FLAREInstructQueryEngine, MultiStepQueryEngine
19
+ )
20
+ from llama_index.core.indices.query.query_transform import (
21
+ StepDecomposeQueryTransform
22
+ )
23
+ from llama_index.llms.groq import Groq
24
+ from llama_index.embeddings.huggingface import HuggingFaceEmbedding
25
+ from llama_index.llms.azure_openai import AzureOpenAI
26
+ from llama_index.embeddings.openai import OpenAIEmbedding
27
+ from llama_index.readers.file import PyMuPDFReader
28
+ import traceback
29
+ from oauth2client.service_account import ServiceAccountCredentials
30
+ import gspread
31
+ import uuid
32
+ from dotenv import load_dotenv
33
+ import os
34
+
35
# Load environment variables from a local .env file (service-account
# credentials, Azure settings, API keys) before anything reads os.getenv.
load_dotenv()

# Configure logging
logging.basicConfig(level=logging.INFO)

# Google Sheets setup: assemble the service-account credential dict from
# individual environment variables instead of shipping a keyfile.
scope = ["https://spreadsheets.google.com/feeds", "https://www.googleapis.com/auth/drive"]
creds_dict = {
    "type": os.getenv("type"),
    "project_id": os.getenv("project_id"),
    "private_key_id": os.getenv("private_key_id"),
    # .env stores the PEM key with literal "\n" sequences; restore real
    # newlines so it parses. NOTE(review): this raises AttributeError if
    # the private_key variable is unset — confirm it is always provided.
    "private_key": os.getenv("private_key").replace('\\n', '\n'),
    "client_email": os.getenv("client_email"),
    "client_id": os.getenv("client_id"),
    "auth_uri": os.getenv("auth_uri"),
    "token_uri": os.getenv("token_uri"),
    "auth_provider_x509_cert_url": os.getenv("auth_provider_x509_cert_url"),
    "client_x509_cert_url": os.getenv("client_x509_cert_url")
}
creds = ServiceAccountCredentials.from_json_keyfile_dict(creds_dict, scope)
client = gspread.authorize(creds)
# Feedback/QA log lives in the first worksheet of the "RAG" spreadsheet.
# NOTE(review): auth + sheet open happen at import time — any network or
# credential failure aborts module import.
sheet = client.open("RAG").sheet1

# Fixed variables (Azure OpenAI deployment configuration)
AZURE_DEPLOYMENT_NAME = os.getenv("AZURE_DEPLOYMENT_NAME")
AZURE_API_VERSION = os.getenv("AZURE_API_VERSION")
AZURE_OPENAI_ENDPOINT = os.getenv("AZURE_OPENAI_ENDPOINT")

# Global variables for lazy loading: populated once by initialize_apis().
llm = None
pinecone_index = None
68
def log_and_exit(message):
    """Record *message* at ERROR level, then terminate.

    Raises:
        SystemExit: always, carrying the same message.
    """
    logging.error(message)
    raise SystemExit(message)
71
+
72
def initialize_apis(api, model, pinecone_api_key, groq_api_key, azure_api_key):
    """Lazily populate the module-level LLM and Pinecone index.

    Each resource is created only on the first call; subsequent calls are
    no-ops for already-initialized resources. Any failure aborts the app
    via log_and_exit.
    """
    global llm, pinecone_index
    try:
        if llm is None:
            llm = initialize_llm(api, model, groq_api_key, azure_api_key)
        if pinecone_index is None:
            # Connect to the fixed "demo" index on Pinecone.
            pc = Pinecone(pinecone_api_key)
            pinecone_index = pc.Index("demo")
        logging.info("Initialized LLM and Pinecone.")
    except Exception as e:
        log_and_exit(f"Error initializing APIs: {e}")
83
+
84
def initialize_llm(api, model, groq_api_key, azure_api_key):
    """Construct the LLM client for the chosen provider.

    Args:
        api: Provider name, 'groq' or 'azure'.
        model: UI model key (e.g. 'llama3-8b', 'gpt35').
        groq_api_key: API key used for the Groq provider.
        azure_api_key: API key used for the Azure provider.

    Returns:
        A Groq or AzureOpenAI LLM instance.

    Raises:
        ValueError: for an unknown provider or a model the provider does
            not support. (Previously these cases silently returned None
            or raised a bare KeyError, deferring the failure to an
            opaque crash downstream; the caller wraps this in
            `except Exception` and exits cleanly.)
    """
    if api == 'groq':
        # Map UI-friendly names to the exact Groq model identifiers.
        model_mappings = {
            'mixtral-8x7b': "mixtral-8x7b-32768",
            'llama3-8b': "llama3-8b-8192",
            'llama3-70b': "llama3-70b-8192",
            'gemma-7b': "gemma-7b-it"
        }
        if model not in model_mappings:
            raise ValueError(f"Unsupported Groq model: {model}")
        return Groq(model=model_mappings[model], api_key=groq_api_key)
    elif api == 'azure':
        if model == 'gpt35':
            return AzureOpenAI(
                deployment_name=AZURE_DEPLOYMENT_NAME,
                temperature=0,
                api_key=azure_api_key,
                azure_endpoint=AZURE_OPENAI_ENDPOINT,
                api_version=AZURE_API_VERSION
            )
        raise ValueError(f"Unsupported Azure model: {model}")
    raise ValueError(f"Unsupported API provider: {api}")
102
+
103
def load_pdf_data(chunk_size):
    """Read the bundled policy.pdf into llama-index Documents.

    Note: *chunk_size* is accepted for interface compatibility but is not
    used here — splitting happens later via Settings.chunk_size /
    SentenceSplitter.
    """
    pdf_path = "policy.pdf"
    loader = SimpleDirectoryReader(
        input_files=[pdf_path],
        file_extractor={".pdf": PyMuPDFReader()},
    )
    return loader.load_data()
109
+
110
def create_index(documents, embedding_model_type="HF", embedding_model="BAAI/bge-large-en-v1.5", retriever_method="BM25", chunk_size=512):
    """Build the retrieval artifacts for the chosen retriever method.

    Returns an (index, nodes) pair whose shape depends on the method:
      - "BM25"        -> (None, nodes)
      - "BM25+Vector" -> (index, nodes)
      - anything else -> (index, None)
    """
    global llm, pinecone_index
    try:
        embed_model = select_embedding_model(embedding_model_type, embedding_model)

        # Register models/chunking on llama-index's global Settings.
        Settings.llm = llm
        Settings.embed_model = embed_model
        Settings.chunk_size = chunk_size

        def _build_vector_index():
            # Pinecone-backed vector index over the documents.
            store = PineconeVectorStore(pinecone_index=pinecone_index)
            ctx = StorageContext.from_defaults(vector_store=store)
            return VectorStoreIndex.from_documents(documents, storage_context=ctx)

        if retriever_method not in ("BM25", "BM25+Vector"):
            index = _build_vector_index()
            logging.info("Created index from documents.")
            return index, None

        nodes = create_bm25_nodes(documents, chunk_size)
        logging.info("Created BM25 nodes from documents.")
        if retriever_method == "BM25+Vector":
            index = _build_vector_index()
            logging.info("Created index for BM25+Vector from documents.")
            return index, nodes
        return None, nodes
    except Exception as e:
        log_and_exit(f"Error creating index: {e}")
137
+
138
def select_embedding_model(embedding_model_type, embedding_model):
    """Return the embedding model for the given backend.

    Args:
        embedding_model_type: "HF" (HuggingFace) or "OAI" (OpenAI).
        embedding_model: model name; only the HF backend uses it — the
            OAI branch falls back to OpenAIEmbedding's default model.

    Raises:
        ValueError: unknown backend. (Previously this returned None
            silently, deferring the failure to an opaque crash when the
            None was registered as Settings.embed_model.)
    """
    if embedding_model_type == "HF":
        return HuggingFaceEmbedding(model_name=embedding_model)
    elif embedding_model_type == "OAI":
        return OpenAIEmbedding()  # Implement OAI Embedding if needed
    raise ValueError(f"Unsupported embedding model type: {embedding_model_type}")
143
+
144
def create_bm25_nodes(documents, chunk_size):
    """Split *documents* into sentence-aware nodes (~chunk_size tokens each) for BM25 retrieval."""
    return SentenceSplitter(chunk_size=chunk_size).get_nodes_from_documents(documents)
148
+
149
def select_retriever(index, nodes, retriever_method, top_k):
    """Build the retriever matching *retriever_method*.

    - "BM25":          BM25 over the pre-split nodes.
    - "BM25+Vector":   LLM-routed choice between a BM25 and a vector tool.
    - "Vector Search": plain vector similarity over *index*.

    Exits via log_and_exit when a required input is missing or the
    method name is unknown.
    """
    logging.info(f"Selecting retriever with method: {retriever_method}")
    if nodes is not None:
        logging.info(f"Available document IDs: {list(range(len(nodes)))}")
    else:
        logging.warning("Nodes are None")

    if retriever_method == 'BM25':
        return BM25Retriever.from_defaults(nodes=nodes, similarity_top_k=top_k)

    if retriever_method == "BM25+Vector":
        if index is None:
            log_and_exit("Index must be initialized when using BM25+Vector retriever method.")
        tools = [
            RetrieverTool.from_defaults(
                retriever=BM25Retriever.from_defaults(nodes=nodes, similarity_top_k=top_k),
                description="BM25 Retriever"
            ),
            RetrieverTool.from_defaults(
                retriever=VectorIndexRetriever(index=index),
                description="Vector Retriever"
            ),
        ]
        # The router lets the LLM pick one or both retrievers per query.
        return RouterRetriever.from_defaults(
            retriever_tools=tools,
            llm=llm,
            select_multi=True
        )

    if retriever_method == "Vector Search":
        if index is None:
            log_and_exit("Index must be initialized when using Vector Search retriever method.")
        return VectorIndexRetriever(index=index, similarity_top_k=top_k)

    log_and_exit(f"Unsupported retriever method: {retriever_method}")
184
+
185
def setup_query_engine(index, response_mode, nodes=None, query_engine_method=None, retriever_method=None, top_k=2):
    """Assemble the query engine for the chosen retriever/engine combination.

    Args:
        index: vector index, or None for BM25-only retrieval.
        response_mode: llama-index response synthesizer mode (e.g. "compact").
        nodes: pre-split nodes (required for the BM25 methods).
        query_engine_method: None (plain retriever engine), "FLARE", or "MS".
        retriever_method: forwarded to select_retriever.
        top_k: similarity_top_k used for retrieval.

    Returns:
        A configured llama-index query engine; exits via log_and_exit on
        any setup failure.
    """
    global llm
    try:
        logging.info(f"Setting up query engine with retriever_method: {retriever_method} and query_engine_method: {query_engine_method}")
        retriever = select_retriever(index, nodes, retriever_method, top_k)

        if retriever is None:
            log_and_exit("Failed to create retriever. Index or nodes might be None.")

        response_synthesizer = get_response_synthesizer(response_mode=response_mode)
        index_query_engine = index.as_query_engine(similarity_top_k=top_k) if index else None

        # BUG FIX: FLARE and MultiStep wrap an index-backed engine. With a
        # BM25-only setup index is None, and the original code constructed
        # them around None, deferring the failure to query time. Fail fast
        # here with a clear message instead.
        if query_engine_method in ("FLARE", "MS") and index_query_engine is None:
            log_and_exit(f"{query_engine_method} query engine requires a vector index, but none was built.")

        if query_engine_method == "FLARE":
            query_engine = FLAREInstructQueryEngine(
                query_engine=index_query_engine,
                max_iterations=4,
                verbose=False
            )
        elif query_engine_method == "MS":
            query_engine = MultiStepQueryEngine(
                query_engine=index_query_engine,
                query_transform=StepDecomposeQueryTransform(llm=llm, verbose=False),
                index_summary="Used to answer questions about the regulation"
            )
        else:
            # Default: retrieve with the selected retriever, synthesize with
            # the configured response mode.
            query_engine = RetrieverQueryEngine(retriever=retriever, response_synthesizer=response_synthesizer)

        return query_engine
    except Exception as e:
        logging.error(f"Error setting up query engine: {e}")
        traceback.print_exc()
        log_and_exit(f"Error setting up query engine: {e}")
220
+
221
def log_to_google_sheets(data):
    """Append one row of query/feedback data to the sheet.

    Failures are logged and swallowed so a Sheets outage never breaks the UI.
    """
    try:
        sheet.append_row(data)
    except Exception as e:
        logging.error(f"Error logging data to Google Sheets: {e}")
    else:
        logging.info("Logged data to Google Sheets.")
227
+
228
def update_google_sheets(question_id, feedback=None, detailed_feedback=None):
    """Write feedback into the row whose first cell equals *question_id*.

    Only the columns whose values were supplied are touched; the first
    matching row wins. Errors are logged and swallowed so UI callbacks
    never crash.
    """
    try:
        rows = sheet.get_all_values()
        headers = rows[0]
        for row_idx, row in enumerate(rows):
            if row[0] != question_id:
                continue
            # gspread cells are 1-based, hence the +1 offsets.
            if feedback is not None:
                sheet.update_cell(row_idx + 1, headers.index("Feedback") + 1, feedback)
            if detailed_feedback is not None:
                sheet.update_cell(row_idx + 1, headers.index("Detailed Feedback") + 1, detailed_feedback)
            logging.info("Updated data in Google Sheets.")
            return
    except Exception as e:
        logging.error(f"Error updating data in Google Sheets: {e}")
242
+
243
def run_streamlit_app():
    """Render the Streamlit RAG chat UI.

    Flow: collect API keys/config -> "Initialize" builds the query engine
    into session state -> chat input runs queries -> each answer offers
    thumbs up/down and free-text feedback, mirrored to Google Sheets.
    """
    if 'query_engine' not in st.session_state:
        st.session_state.query_engine = None

    st.title("RAG Chat Application")

    col1, col2 = st.columns(2)

    with col1:
        # NOTE(review): consider type="password" on these inputs so keys
        # are not displayed in clear text.
        pinecone_api_key = st.text_input("Pinecone API Key")
        parse_api_key = st.text_input("Parse API Key")
        azure_api_key = st.text_input("Azure API Key")
        groq_api_key = st.text_input("Groq API Key")

    with col2:
        selected_api = st.selectbox("Select API", ["azure", "groq"])
        selected_model = st.selectbox("Select Model", ["llama3-8b", "llama3-70b", "mixtral-8x7b", "gemma-7b", "gpt35"])
        embedding_model_type = "HF"
        embedding_model = st.selectbox("Select Embedding Model", ["BAAI/bge-large-en-v1.5", "other_model"])
        retriever_method = st.selectbox("Select Retriever Method", ["Vector Search", "BM25", "BM25+Vector"])

    col3, col4 = st.columns(2)
    with col3:
        chunk_size = st.selectbox("Select Chunk Size", [128, 256, 512, 1024], index=2)
    with col4:
        top_k = st.selectbox("Select Top K", [1, 2, 3, 5, 6], index=1)

    if st.button("Initialize"):
        # Full pipeline: clients -> documents -> index/nodes -> query engine.
        initialize_apis(selected_api, selected_model, pinecone_api_key, groq_api_key, azure_api_key)
        documents = load_pdf_data(chunk_size)
        index, nodes = create_index(documents, embedding_model_type=embedding_model_type, embedding_model=embedding_model, retriever_method=retriever_method, chunk_size=chunk_size)
        st.session_state.query_engine = setup_query_engine(index, response_mode="compact", nodes=nodes, query_engine_method=None, retriever_method=retriever_method, top_k=top_k)
        st.success("Initialization complete.")

    if 'chat_history' not in st.session_state:
        st.session_state.chat_history = []

    # Replay the conversation so far, with feedback controls per answer.
    for chat_index, chat in enumerate(st.session_state.chat_history):
        with st.chat_message("user"):
            st.markdown(chat['user'])
        with st.chat_message("bot"):
            st.markdown("### Retrieved Contexts")
            for node in chat.get('contexts', []):
                st.markdown(
                    f"<div style='border:1px solid #ccc; padding:10px; margin:10px 0; font-size:small;'>{node.text}</div>",
                    unsafe_allow_html=True
                )
            st.markdown("### Answer")
            st.markdown(chat['response'])
        col1, col2, col3 = st.columns([1, 1, 3])
        with col1:
            if st.button("👍", key=f"up_{chat_index}"):
                # BUG FIX: history entries are created with feedback == 0,
                # so the old "'feedback' not in chat" guard was always False
                # and votes were never recorded. Accept the vote while the
                # feedback is still unset (0).
                if chat.get('feedback', 0) == 0:
                    chat['feedback'] = 1
                    st.session_state.chat_history[chat_index] = chat
                    update_google_sheets(chat['id'], feedback=1)
        with col2:
            if st.button("👎", key=f"down_{chat_index}"):
                if chat.get('feedback', 0) == 0:
                    chat['feedback'] = -1
                    st.session_state.chat_history[chat_index] = chat
                    update_google_sheets(chat['id'], feedback=-1)
        with col3:
            feedback = st.text_area("How was the response? Does it match the context? Does it answer the question fully?", key=f"textarea_{chat_index}")
            if st.button("Submit Feedback", key=f"submit_{chat_index}"):
                chat['detailed_feedback'] = feedback
                st.session_state.chat_history[chat_index] = chat
                update_google_sheets(chat['id'], detailed_feedback=feedback)

    if question := st.chat_input("Enter your question"):
        if st.session_state.query_engine:
            with st.spinner('Generating response...'):
                response = st.session_state.query_engine.query(question)
                logging.info(f"Generated response: {response.response}")
                logging.info(f"Retrieved contexts: {[node.text for node in response.source_nodes]}")
                question_id = str(uuid.uuid4())
                st.session_state.chat_history.append({'id': question_id, 'user': question, 'response': response.response, 'contexts': response.source_nodes, 'feedback': 0, 'detailed_feedback': ''})

                # Log initial query and response to Google Sheets without feedback
                log_to_google_sheets([question_id, question, response.response, selected_api, selected_model, embedding_model, retriever_method, chunk_size, top_k, 0, ""])

                st.rerun()
        else:
            st.error("Query engine is not initialized. Please initialize it first.")
327
+
328
# Entry point: launch the chat UI when this module is executed directly
# (e.g. via `streamlit run main_hf.py`).
if __name__ == "__main__":
    run_streamlit_app()
policy.pdf ADDED
Binary file (463 kB). View file