Ritesh-hf committed on
Commit
57d62f7
·
1 Parent(s): 53ab5a4

update index and reranker

Browse files
Files changed (3) hide show
  1. .gitignore +2 -0
  2. app.py +26 -11
  3. mbzuai-policies.json +0 -0
.gitignore ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ .env
2
+ *.ipynb
app.py CHANGED
@@ -4,6 +4,9 @@ monkey.patch_all()
4
  import nltk
5
  nltk.download('punkt_tab')
6
 
 
 
 
7
  import os
8
  from dotenv import load_dotenv
9
  import asyncio
@@ -20,10 +23,10 @@ from pinecone import Pinecone
20
  from pinecone_text.sparse import BM25Encoder
21
  from langchain_huggingface import HuggingFaceEmbeddings
22
  from langchain_community.retrievers import PineconeHybridSearchRetriever
23
- from langchain_groq import ChatGroq
24
  from langchain.retrievers import ContextualCompressionRetriever
25
- from langchain.retrievers.document_compressors import FlashrankRerank
26
  from langchain_community.chat_models import ChatPerplexity
 
 
27
 
28
  # Load environment variables
29
  load_dotenv(".env")
@@ -62,7 +65,7 @@ def initialize_pinecone(index_name: str):
62
  ##################################################
63
 
64
  # Initialize Pinecone index and BM25 encoder
65
- pinecone_index = initialize_pinecone("updated-mbzuai-policies")
66
  bm25 = BM25Encoder().load("./new_mbzuai-policies.json")
67
 
68
  ##################################################
@@ -77,7 +80,8 @@ retriever = PineconeHybridSearchRetriever(
77
  sparse_encoder=bm25,
78
  index=pinecone_index,
79
  top_k=20,
80
- alpha=0.5
 
81
  )
82
 
83
  # Initialize LLM
@@ -86,7 +90,11 @@ llm = ChatPerplexity(temperature=0, pplx_api_key=GROQ_API_KEY, model="llama-3.1-
86
 
87
 
88
  # Initialize Reranker
89
- compressor = FlashrankRerank()
 
 
 
 
90
  compression_retriever = ContextualCompressionRetriever(
91
  base_compressor=compressor, base_retriever=retriever
92
  )
@@ -191,14 +199,21 @@ def handle_message(data):
191
  else:
192
  language = "Arabic"
193
  session_id = data.get('session_id', SESSION_ID_DEFAULT)
194
- chain = conversational_rag_chain.pick("answer")
 
 
 
 
 
 
 
 
 
 
195
 
196
  try:
197
- for chunk in chain.stream(
198
- {"input": question, 'language': language},
199
- config={"configurable": {"session_id": session_id}},
200
- ):
201
- emit('response', chunk, room=request.sid)
202
  except Exception as e:
203
  print(f"Error during message handling: {e}")
204
  emit('response', "An error occurred while processing your request." + str(e), room=request.sid)
 
4
  import nltk
5
  nltk.download('punkt_tab')
6
 
7
+ import nltk
8
+ nltk.download('punkt_tab')
9
+
10
  import os
11
  from dotenv import load_dotenv
12
  import asyncio
 
23
  from pinecone_text.sparse import BM25Encoder
24
  from langchain_huggingface import HuggingFaceEmbeddings
25
  from langchain_community.retrievers import PineconeHybridSearchRetriever
 
26
  from langchain.retrievers import ContextualCompressionRetriever
 
27
  from langchain_community.chat_models import ChatPerplexity
28
+ from langchain.retrievers.document_compressors import CrossEncoderReranker
29
+ from langchain_community.cross_encoders import HuggingFaceCrossEncoder
30
 
31
  # Load environment variables
32
  load_dotenv(".env")
 
65
  ##################################################
66
 
67
  # Initialize Pinecone index and BM25 encoder
68
+ pinecone_index = initialize_pinecone("updated-mbzuai-policies-17112024")
69
  bm25 = BM25Encoder().load("./new_mbzuai-policies.json")
70
 
71
  ##################################################
 
80
  sparse_encoder=bm25,
81
  index=pinecone_index,
82
  top_k=20,
83
+ alpha=0.5,
84
+
85
  )
86
 
87
  # Initialize LLM
 
90
 
91
 
92
  # Initialize Reranker
93
+ # compressor = FlashrankRerank()
94
+ model = HuggingFaceCrossEncoder(model_name="BAAI/bge-reranker-base")
95
+ compressor = CrossEncoderReranker(model=model, top_n=20)
96
+
97
+
98
  compression_retriever = ContextualCompressionRetriever(
99
  base_compressor=compressor, base_retriever=retriever
100
  )
 
199
  else:
200
  language = "Arabic"
201
  session_id = data.get('session_id', SESSION_ID_DEFAULT)
202
+ # chain = conversational_rag_chain.pick("answer")
203
+
204
+ # try:
205
+ # for chunk in conversational_rag_chain.stream(
206
+ # {"input": question, 'language': language},
207
+ # config={"configurable": {"session_id": session_id}},
208
+ # ):
209
+ # emit('response', chunk, room=request.sid)
210
+ # except Exception as e:
211
+ # print(f"Error during message handling: {e}")
212
+ # emit('response', "An error occurred while processing your request." + str(e), room=request.sid)
213
 
214
  try:
215
+ response = conversational_rag_chain.invoke({"input": question, 'language': language}, config={"configurable": {"session_id": session_id}})
216
+ emit('response', response, room=request.sid)
 
 
 
217
  except Exception as e:
218
  print(f"Error during message handling: {e}")
219
  emit('response', "An error occurred while processing your request." + str(e), room=request.sid)
mbzuai-policies.json CHANGED
The diff for this file is too large to render. See raw diff