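"""Chainlit front end for retrieval-augmented QA over arXiv papers.

On session start the app ingests documents from arXiv, indexes them in
Pinecone, and wires up a LangChain LCEL chain around gpt-3.5-turbo; each user
message is then answered from the retrieved context only.

Launch with `chainlit run <this file> -w`. It is assumed that the OpenAI and
Pinecone credentials expected by the project-local `utils` module are
configured in the environment.
"""
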
from operator import itemgetter
import chainlit as cl
from langchain.chat_models import ChatOpenAI
from langchain.prompts.chat import (
    ChatPromptTemplate,
    SystemMessagePromptTemplate,
    HumanMessagePromptTemplate,
)

from utils import ArxivLoader, PineconeIndexer
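
# `utils` is project-local and not shown here; based on the calls below it is
# assumed that ArxivLoader.main() fetches and prepares the arXiv documents, and
# that PineconeIndexer exposes load_embedder(), get_vectorstore(), and an
# `index` handle to the underlying Pinecone index.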

system_template = """
Use the provided context to answer the user's query.

You may not answer the user's query unless there is specific context in the following text.

If you do not know the answer, or cannot answer, please respond with "I don't know".

Context:
{context}
"""

messages = [
    SystemMessagePromptTemplate.from_template(system_template),
    HumanMessagePromptTemplate.from_template("{question}"),
]

prompt = ChatPromptTemplate.from_messages(messages)

@cl.author_rename  # remaps the author label shown next to chain messages in the UI
def rename(orig_author: str):
    rename_dict = {"RetrievalQA": "Learning about Nuclear Fission"}
    return rename_dict.get(orig_author, orig_author)

@cl.on_chat_start  # marks a function that will be executed at the start of a user session
async def start_chat():
    msg = cl.Message(content="Initializing the application...")
    await msg.send()

    # load documents from Arxiv
    axloader = ArxivLoader()
    axloader.main()

    # load embedder and the retriever
    pi = PineconeIndexer()
    pi.load_embedder()
    retriever = pi.get_vectorstore().as_retriever()
    print(pi.index.describe_index_stats())  # sanity check: confirm vectors were ingested

    # build llm
    llm = ChatOpenAI(
        model="gpt-3.5-turbo",
        temperature=0
    )

    msg = cl.Message(content="The application is ready!")
    await msg.send()

    cl.user_session.set("llm", llm)
    cl.user_session.set("retriever", retriever)

@cl.on_message  # marks a function that should be run each time the chatbot receives a message from a user
async def main(message: cl.Message):
    llm = cl.user_session.get("llm")
    retriever = cl.user_session.get("retriever")

    # LCEL graph, rebuilt per message from the session's objects: retrieve
    # context for the question, then return both the LLM response and the
    # documents it was grounded on.
    retrieval_augmented_qa_chain = (
        {
            "context": itemgetter("question") | retriever,
            "question": itemgetter("question"),
        }
        | {
            "response": prompt | llm,
            "context": itemgetter("context"),
        }
    )

    # The chain returns {"response": AIMessage, "context": [Document, ...]};
    # the UI only needs the response text.
    answer = await retrieval_augmented_qa_chain.ainvoke({"question": message.content})

    await cl.Message(content=answer["response"].content).send()