File size: 4,290 Bytes
9253faf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85824c0
 
 
 
9253faf
 
85824c0
9253faf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85824c0
 
 
9253faf
 
 
 
 
 
 
 
9e54bf7
9253faf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import chainlit as cl
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.document_loaders.csv_loader import CSVLoader
from langchain.embeddings import CacheBackedEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import FAISS
from langchain.chains import RetrievalQA
from langchain.chat_models import ChatOpenAI
from langchain.storage import LocalFileStore
from langchain.prompts.chat import (
    ChatPromptTemplate,
    SystemMessagePromptTemplate,
    HumanMessagePromptTemplate,
)
import chainlit as cl

# Splitter shared by every session: ~1000-character chunks, with 100 characters
# of overlap between neighbouring chunks.
text_splitter = RecursiveCharacterTextSplitter(chunk_overlap=100, chunk_size=1000)

# System prompt for the QA chain. The "{context}" placeholder is filled in by
# the retrieval chain (presumably with the retrieved document chunks); the
# persona text sets the tone of every answer. Lines are spelled out one by one
# so significant whitespace (e.g. the space before the newline after
# "analysis .") stays visible.
system_template = (
    "\n"
    "Use the following pieces of context to answer the user's question.\n"
    'Please respond as if you are "RoaringKitty" a Youtuber known for detailed '
    "posts and videos on social media platforms like Reddit (particularly the "
    "WallStreetBets subreddit) and YouTube, where he shared his investment "
    "strategies and analysis . \n"
    "If you don't know the answer, just say that you don't know, don't try to make up an answer.\n"
    "You can make inferences based on the context as long as it still faithfully represents the feedback.\n"
    "Example of your response should be:\n"
    "\n"
    "```\n"
    "The answer is foo\n"
    "```\n"
    "\n"
    "Begin!\n"
    "----------------\n"
    "{context}"
)

_system_prompt = SystemMessagePromptTemplate.from_template(system_template)
# The human turn is just the raw user question.
_human_prompt = HumanMessagePromptTemplate.from_template("{question}")

messages = [_system_prompt, _human_prompt]
prompt = ChatPromptTemplate(messages=messages)

# Extra keyword arguments forwarded to RetrievalQA.from_chain_type(...).
chain_type_kwargs = dict(prompt=prompt)

@cl.author_rename
def rename(orig_author: str):
    """Translate an internal author name into the name shown in the chat UI.

    Args:
        orig_author: Author name as reported for a message or step.

    Returns:
        "Diamond Hands" framed by two blue-diamond emoji when the author is
        the ``RetrievalQA`` chain; otherwise ``orig_author`` unchanged.
    """
    if orig_author != "RetrievalQA":
        return orig_author
    diamond = "\U0001F537"  # U+1F537 LARGE BLUE DIAMOND
    return f"{diamond} Diamond Hands {diamond}"

def _build_docsearch():
    """Load the RoaringKitty CSV, chunk it, and index the chunks in FAISS.

    Every step here is synchronous (file reads, embedding requests, index
    construction), so async callers should run it via ``cl.make_async``.

    Returns:
        The FAISS vector store built from the chunked CSV rows.
    """
    # source_column="Link" makes each row's Link value its "source" metadata.
    loader = CSVLoader(file_path="./data/roaringkitty.csv", source_column="Link")
    documents = text_splitter.transform_documents(loader.load())

    # Embeddings are cached on disk under ./cache/, namespaced by model name,
    # so repeated builds can reuse previously computed vectors.
    store = LocalFileStore("./cache/")
    core_embeddings_model = OpenAIEmbeddings()
    embedder = CacheBackedEmbeddings.from_bytes_store(
        core_embeddings_model, store, namespace=core_embeddings_model.model
    )
    return FAISS.from_documents(documents, embedder)


@cl.on_chat_start
async def init():
    """Build the retrieval QA chain for a new chat session.

    Shows a progress message, builds the FAISS index off the event loop,
    wires it into a GPT-4 ``RetrievalQA`` chain using the module-level
    prompt, and stores the chain in the user session under ``"chain"``.

    NOTE(review): the index is rebuilt for every chat session (embeddings are
    cached, the FAISS index is not); consider building it once per process.
    """
    msg = cl.Message(content="Building Index...")
    await msg.send()

    # Loading, splitting and embedding are blocking; keep them off the event loop.
    docsearch = await cl.make_async(_build_docsearch)()

    chain = RetrievalQA.from_chain_type(
        ChatOpenAI(model="gpt-4", temperature=0, streaming=True),
        chain_type="stuff",
        return_source_documents=True,
        retriever=docsearch.as_retriever(),
        chain_type_kwargs=chain_type_kwargs,
    )

    # Store the chain *before* announcing readiness, so a message sent right
    # after "Index built!" appears is guaranteed to find it in the session.
    cl.user_session.set("chain", chain)

    # Edit the existing progress message instead of sending it a second time.
    msg.content = "Index built!"
    await msg.update()


@cl.on_message
async def main(message):
    """Answer a user message with the session's RetrievalQA chain.

    Args:
        message: The incoming user message. Older Chainlit versions pass the
            text as a ``str``; newer ones pass a ``cl.Message`` whose text is
            in ``.content``. Both forms are accepted.

    Sends one reply containing the model's answer followed by the
    de-duplicated YouTube links of the retrieved source documents (or
    "No sources found"), with one ``cl.Text`` element attached per link.
    """
    chain = cl.user_session.get("chain")
    if chain is None:
        # on_chat_start has not stored a chain yet (still indexing, or it failed).
        await cl.Message(
            content='The index is not ready yet. Please wait for "Index built!" and try again.'
        ).send()
        return

    # str has no .content attribute, so plain-text messages pass through as-is.
    query = getattr(message, "content", message)

    cb = cl.AsyncLangchainCallbackHandler(
        stream_final_answer=False, answer_prefix_tokens=["FINAL", "ANSWER"]
    )
    # NOTE(review): kept from the original; with stream_final_answer=False this
    # flag may have no effect — confirm against the Chainlit version in use.
    cb.answer_reached = True
    res = await chain.acall(query, callbacks=[cb])

    answer = res["result"]

    # Source ids come from the chain result (return_source_documents=True);
    # keep each unique one once, in first-seen order.
    sources = dict.fromkeys(
        doc.metadata["source"]
        for doc in res["source_documents"]
        if "source" in doc.metadata
    )
    links = ["https://www.youtube.com/watch?" + source for source in sources]
    source_elements = [cl.Text(content=link, name="Link to Video") for link in links]

    # Join the plain link strings: cl.Text.content may be str (no .decode) or
    # bytes depending on the Chainlit version, so it is not read back here.
    if links:
        answer += f"\nSources: {', '.join(links)}"
    else:
        answer += "\nNo sources found"

    await cl.Message(content=answer, elements=source_elements).send()