# mlara's picture
# first commit
# 0accfd8
# raw
# history blame
# 2.69 kB
import sys
sys.path.append(".")
import chainlit as cl
from langchain.prompts.chat import (
ChatPromptTemplate,
SystemMessagePromptTemplate,
HumanMessagePromptTemplate,
)
from roaringkitty import roaringkiity_chain
# System prompt for the RoaringKitty persona. `{context}` is a template
# variable; presumably the retrieval chain built by `roaringkiity_chain` fills
# it with the retrieved document text (TODO confirm). The user's question is
# supplied through the `{question}` human-message template below.
system_template = """
Use the following pieces of context to answer the user's question.
Please respond as if you are "RoaringKitty", a YouTuber known for detailed posts and videos on social media platforms like Reddit (particularly the WallStreetBets subreddit) and YouTube, where he shared his investment strategies and analysis.
If you don't know the answer, just say that you don't know, don't try to make up an answer.
You can make inferences based on the context as long as they still faithfully represent it.
Example of your response should be:
```
The answer is foo
```
Begin!
----------------
{context}"""

# Chat prompt: persona/system instructions first, then the user's question.
messages = [
    SystemMessagePromptTemplate.from_template(system_template),
    HumanMessagePromptTemplate.from_template("{question}"),
]
prompt = ChatPromptTemplate(messages=messages)

# NOTE(review): not referenced elsewhere in this file (init passes `prompt`
# directly); kept in case other modules import it for `chain_type_kwargs=`.
chain_type_kwargs = {"prompt": prompt}
@cl.author_rename
def rename(orig_author: str) -> str:
    """Map an internal author name to the label shown in the chat UI.

    Messages authored by "RetrievalQA" are displayed as "Diamond Hands"
    framed by blue-diamond emoji; any other author name is returned as-is.
    """
    if orig_author != "RetrievalQA":
        return orig_author
    diamond = "\U0001F537"  # U+1F537 LARGE BLUE DIAMOND
    return f"{diamond} Diamond Hands {diamond}"
@cl.on_chat_start
async def init():
    """Build the retrieval chain for a new chat session and store it.

    Posts a "Building Index..." status message, awaits
    ``roaringkiity_chain(prompt)``, then edits that same message to
    "Index built!" and stores the resulting chain in the user session under
    the key ``"chain"`` (read back by the ``on_message`` handler).
    """
    msg = cl.Message(content="Building Index...")
    await msg.send()
    chain = await roaringkiity_chain(prompt)
    # Edit the already-sent status message in place; calling send() a second
    # time can post the status as a new, duplicate message.
    msg.content = "Index built!"
    await msg.update()
    cl.user_session.set("chain", chain)
@cl.on_message
async def main(message):
    """Answer a user message with the session's chain and list video sources.

    Args:
        message: The incoming user message. Depending on the Chainlit
            version this is either the raw text (``str``) or a ``cl.Message``;
            its text is passed to the chain as the query.

    Sends one reply containing the chain's ``"result"``, followed by the
    de-duplicated YouTube links built from each returned document's
    ``metadata["source"]``. One ``cl.Text`` element is attached per link.
    """
    chain = cl.user_session.get("chain")
    cb = cl.AsyncLangchainCallbackHandler(
        stream_final_answer=False, answer_prefix_tokens=["FINAL", "ANSWER"]
    )
    cb.answer_reached = True
    # Newer Chainlit hands us a cl.Message; older versions pass the text itself.
    query = getattr(message, "content", message)
    res = await chain.acall(query, callbacks=[cb])
    answer = res["result"]

    # Retrieved documents come back in the chain result; the key is only
    # present when the chain was built to return source documents.
    docs = res.get("source_documents") or []
    # dict.fromkeys de-duplicates while keeping first-seen order.
    sources = dict.fromkeys(
        doc.metadata["source"] for doc in docs if "source" in doc.metadata
    )
    # NOTE(review): assumes each "source" is a watch-URL query string such as
    # "v=<video id>" — confirm against how the index metadata is built.
    urls = ["https://www.youtube.com/watch?" + source for source in sources]

    source_elements = [cl.Text(content=url, name="Link to Video") for url in urls]
    if urls:
        # Join the plain URL strings: cl.Text.content may be str (no .decode)
        # or bytes depending on the Chainlit version.
        answer += f"\nSources: {', '.join(urls)}"
    else:
        answer += "\nNo sources found"
    await cl.Message(content=answer, elements=source_elements).send()