from operator import itemgetter

import chainlit as cl
from langchain.chat_models import ChatOpenAI
from langchain.prompts.chat import (
    ChatPromptTemplate,
    SystemMessagePromptTemplate,
    HumanMessagePromptTemplate,
)
from langchain.schema.runnable import RunnablePassthrough

from utils import ArxivLoader, PineconeIndexer
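
# `utils` is project-local and not shown in this file. Assumed behavior, inferred
# from how the objects are used below: ArxivLoader.main() fetches and chunks the
# arXiv papers, PineconeIndexer.load_embedder() initializes the embedding model,
# .get_vectorstore() returns a LangChain vector store backed by the Pinecone
# index, and .index exposes the underlying Pinecone index handle.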
system_template = """
Use the provided context to answer the user's query.
You may not answer the user's query unless there is specific context in the following text.
If you do not know the answer, or cannot answer, please respond with "I don't know".
Context:
{context}
"""
messages = [
    SystemMessagePromptTemplate.from_template(system_template),
    HumanMessagePromptTemplate.from_template("{question}"),
]
prompt = ChatPromptTemplate.from_messages(messages)
@cl.author_rename
def rename(orig_author: str):
    # Relabel the default chain author shown in the Chainlit UI.
    rename_dict = {"RetrievalQA": "Learning about Nuclear Fission"}
    return rename_dict.get(orig_author, orig_author)
@cl.on_chat_start  # runs once at the start of a user session
async def start_chat():
    msg = cl.Message(content="Initializing the Application...")
    await msg.send()

    # Load documents from arXiv and index them.
    axloader = ArxivLoader()
    axloader.main()

    # Load the embedder and build the retriever on top of the Pinecone index.
    pi = PineconeIndexer()
    pi.load_embedder()
    retriever = pi.get_vectorstore().as_retriever()
    print(pi.index.describe_index_stats())  # sanity check: vector count and fullness

    # Build the LLM.
    llm = ChatOpenAI(
        model="gpt-3.5-turbo",
        temperature=0,
    )

    msg = cl.Message(content="Application is ready!")
    await msg.send()

    # Stash the pieces in the per-user session for use in the message handler.
    cl.user_session.set("llm", llm)
    cl.user_session.set("retriever", retriever)
@cl.on_message  # runs each time the chatbot receives a message from a user
async def main(message: cl.Message):
    llm = cl.user_session.get("llm")
    retriever = cl.user_session.get("retriever")

    # LCEL pipeline: retrieve context for the question, pass both through, then
    # return the LLM response alongside the context that produced it.
    retrieval_augmented_qa_chain = (
        {
            "context": itemgetter("question") | retriever,
            "question": itemgetter("question"),
        }
        | RunnablePassthrough.assign(context=itemgetter("context"))
        | {
            "response": prompt | llm,
            "context": itemgetter("context"),
        }
    )

    answer = retrieval_augmented_qa_chain.invoke({"question": message.content})
    await cl.Message(content=answer["response"].content).send()
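
# Typical local launch with the Chainlit CLI (assuming this file is app.py):
#   chainlit run app.py -w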