from operator import itemgetter

import chainlit as cl

from langchain.schema.runnable import RunnablePassthrough
from langchain.chat_models import ChatOpenAI
from langchain.prompts.chat import (
    ChatPromptTemplate,
    SystemMessagePromptTemplate,
    HumanMessagePromptTemplate,
)

# Project-local helpers: based on their use below, ArxivLoader downloads and
# chunks the source papers, and PineconeIndexer embeds them into a Pinecone index.
from utils import ArxivLoader, PineconeIndexer

system_template = """
Use the provided context to answer the user's query.

You may not answer the user's query unless there is specific context in the following text.

If you do not know the answer, or cannot answer, please respond with "I don't know".

Context:
{context}
"""

messages = [
    SystemMessagePromptTemplate.from_template(system_template),
    HumanMessagePromptTemplate.from_template("{question}"),
]

prompt = ChatPromptTemplate.from_messages(messages)
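
# Sanity check of the assembled template (hypothetical inputs; safe to delete):
#     prompt.format_messages(
#         context="Fission splits a heavy nucleus and releases energy.",
#         question="What is nuclear fission?",
#     )
# yields a SystemMessage carrying the context and a HumanMessage with the question.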


@cl.author_rename
def rename(orig_author: str):
    # Relabel the chain's default author name in the Chainlit UI.
    rename_dict = {"RetrievalQA": "Learning about Nuclear Fission"}
    return rename_dict.get(orig_author, orig_author)


@cl.on_chat_start
async def start_chat():
    msg = cl.Message(content="Initializing the Application...")
    await msg.send()

    # Download and chunk the source papers from arXiv.
    axloader = ArxivLoader()
    axloader.main()

    # Embed the chunks into Pinecone and expose the index as a retriever.
    pi = PineconeIndexer()
    pi.load_embedder()
    retriever = pi.get_vectorstore().as_retriever()
    print(pi.index.describe_index_stats())
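    # The retriever defaults to top-k similarity search; if tuning is needed,
    # as_retriever accepts search parameters (hypothetical k value shown):
    #     retriever = pi.get_vectorstore().as_retriever(search_kwargs={"k": 3})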

    # temperature=0 keeps answers deterministic and grounded in the context.
    llm = ChatOpenAI(
        model="gpt-3.5-turbo",
        temperature=0,
    )

    msg = cl.Message(content="Application is ready!")
    await msg.send()

    # Stash the LLM and retriever for use by the message handler.
    cl.user_session.set("llm", llm)
    cl.user_session.set("retriever", retriever)


@cl.on_message
async def main(message: cl.Message):
    llm = cl.user_session.get("llm")
    retriever = cl.user_session.get("retriever")

    # LCEL graph: retrieve context for the question, pass both values forward,
    # then return the model's response alongside the context used to produce it.
    retrieval_augmented_qa_chain = (
        {
            "context": itemgetter("question") | retriever,
            "question": itemgetter("question"),
        }
        | RunnablePassthrough.assign(context=itemgetter("context"))
        | {
            "response": prompt | llm,
            "context": itemgetter("context"),
        }
    )
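    # Note: the retriever returns a list of Documents, which the prompt will
    # stringify wholesale. A common refinement (hypothetical helper, not part
    # of this app) is to join page contents before templating:
    #     format_docs = lambda docs: "\n\n".join(d.page_content for d in docs)
    #     {"context": itemgetter("question") | retriever | format_docs, ...}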

    answer = retrieval_augmented_qa_chain.invoke({"question": message.content})

    await cl.Message(content=answer["response"].content).send()
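

# Launch with the Chainlit CLI, e.g. `chainlit run app.py -w` (filename assumed);
# OPENAI_API_KEY must be set for ChatOpenAI, and the Pinecone credentials are
# assumed to be read from the environment by PineconeIndexer.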