Spaces:
Sleeping
Sleeping
Jaspertw177
committed on
Commit
•
7f8ded9
1
Parent(s):
eb0fef7
init
Browse files- .gitignore +3 -0
- app.py +7 -0
- chat.py +121 -0
- pages/Chatbot.py +45 -0
- pages/Chatbot_with_uploaded_docs.py +69 -0
- requirements.txt +10 -0
- utils.py +63 -0
.gitignore
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
__pycache__/
|
2 |
+
*.pyc
|
3 |
+
.streamlit/
|
app.py
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
from streamlit.external.langchain import StreamlitCallbackHandler
|
3 |
+
|
4 |
+
# Landing page: configure the browser tab, then point the visitor at the
# sidebar, where the individual demo pages are listed.
st.set_page_config(
    page_title="ChatBot",
    page_icon="🤭",
)
st.title("CHOOSE FROM THE SIDEBAR")
st.sidebar.success("Select a demo above 🐮")
|
chat.py
ADDED
@@ -0,0 +1,121 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
import os
|
3 |
+
import tempfile
|
4 |
+
from langchain.chains import ConversationalRetrievalChain, ConversationChain
|
5 |
+
from langchain_openai import AzureOpenAIEmbeddings, AzureChatOpenAI
|
6 |
+
from langchain.retrievers import ContextualCompressionRetriever
|
7 |
+
from langchain.retrievers.document_compressors import EmbeddingsFilter
|
8 |
+
from langchain.schema import BaseRetriever, Document
|
9 |
+
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
10 |
+
from langchain_community.vectorstores import DocArrayInMemorySearch
|
11 |
+
from langchain.agents import initialize_agent, AgentType
|
12 |
+
from langchain_community.agent_toolkits.load_tools import load_tools
|
13 |
+
from utils import MEMORY, load_document
|
14 |
+
import streamlit as st
|
15 |
+
|
16 |
+
# Configure root logging at INFO with UTF-8 output and keep a handle on the
# root logger for this module.
logging.basicConfig(level=logging.INFO, encoding="utf-8")
LOGGER = logging.getLogger()
|
18 |
+
|
19 |
+
def config_retriever(docs: list[Document], use_compression=False, chunk_size=1500):
    """Build a retriever over ``docs`` backed by an in-memory vector store.

    Args:
        docs: Documents to split, embed and index.
        use_compression: When true, wrap the base retriever in a
            ``ContextualCompressionRetriever`` driven by an
            ``EmbeddingsFilter`` (similarity threshold 0.2).
        chunk_size: ``chunk_size`` handed to the text splitter; the chunk
            overlap is fixed at 200.

    Returns:
        The MMR retriever, or the compression retriever wrapping it.
    """
    chunks = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size, chunk_overlap=200
    ).split_documents(docs)

    # Azure OpenAI connection settings all come from Streamlit secrets.
    embedder = AzureOpenAIEmbeddings(
        api_key=st.secrets['key'],
        azure_deployment=st.secrets['embedding_name'],
        openai_api_version=st.secrets['embedding_version'],
        azure_endpoint=st.secrets['endpoint'],
    )

    mmr_options = {"k": 5, "fetch_k": 7, "include_metadata": True}
    base_retriever = DocArrayInMemorySearch.from_documents(chunks, embedder).as_retriever(
        search_type='mmr',
        search_kwargs=mmr_options,
    )

    if use_compression:
        # The same embedding model is reused by the filter.
        similarity_filter = EmbeddingsFilter(
            embeddings=embedder, similarity_threshold=0.2
        )
        return ContextualCompressionRetriever(
            base_compressor=similarity_filter,
            base_retriever=base_retriever,
        )
    return base_retriever
|
49 |
+
|
50 |
+
def config_baseretrieval_chain(retriever: BaseRetriever, temperature=0.1):
    """Create the conversational retrieval chain for ``retriever``.

    Args:
        retriever: Retriever the chain queries.
        temperature: Sampling temperature passed to the Azure chat model.

    Returns:
        The chain produced by ``ConversationalRetrievalChain.from_llm``,
        wired to the shared ``MEMORY`` object with ``verbose=True``.
    """
    chat_model = AzureChatOpenAI(
        api_key=st.secrets['key'],
        openai_api_version=st.secrets['chat_version'],
        azure_deployment=st.secrets['chat_name'],
        azure_endpoint=st.secrets['endpoint'],
        temperature=temperature,
    )

    # Side effect on the shared memory object: its output key is (re)set to
    # 'answer' before the chain is built.
    MEMORY.output_key = 'answer'

    return ConversationalRetrievalChain.from_llm(
        llm=chat_model,
        retriever=retriever,
        memory=MEMORY,
        verbose=True,
    )
|
67 |
+
|
68 |
+
def ddg_search_agent(temperature=0.1):
    """Create a zero-shot ReAct agent equipped with the 'ddg-search' tool.

    Args:
        temperature: Sampling temperature passed to the Azure chat model.

    Returns:
        The agent returned by ``initialize_agent`` (verbose, with
        ``handle_parsing_errors=True``).
    """
    chat_model = AzureChatOpenAI(
        api_key=st.secrets['key'],
        openai_api_version=st.secrets['chat_version'],
        azure_deployment=st.secrets['chat_name'],
        azure_endpoint=st.secrets['endpoint'],
        temperature=temperature,
    )

    search_tools = load_tools(
        tool_names=['ddg-search'],
        llm=chat_model,
        model="gpt-4o-mini",
    )

    agent_options = {
        "agent": AgentType.ZERO_SHOT_REACT_DESCRIPTION,
        "verbose": True,
        "handle_parsing_errors": True,
    }
    return initialize_agent(tools=search_tools, llm=chat_model, **agent_options)
|
85 |
+
|
86 |
+
def config_retrieval_chain(
    upload_files,
    use_compression=False,
    use_chunksize=1500,
    use_temperature=0.1,
    use_zeroshoot=False
):
    """Build the chain (or search agent) for the uploaded-documents page.

    Args:
        upload_files: Iterable of uploaded file objects; each must expose
            ``.name`` and ``.getvalue()``.
        use_compression: Forwarded to ``config_retriever``.
        use_chunksize: Chunk size forwarded to ``config_retriever``.
        use_temperature: Temperature for the chat model.
        use_zeroshoot: When true, return the DuckDuckGo search agent; the
            uploaded files are not used in that mode.

    Returns:
        The agent from ``ddg_search_agent`` when ``use_zeroshoot`` is set,
        otherwise the chain from ``config_baseretrieval_chain``.
    """
    # The search agent never touches the documents, so skip writing, parsing
    # and embedding them when it is the one being requested.
    if use_zeroshoot:
        return ddg_search_agent(temperature=use_temperature)

    docs = []
    # The context manager removes the scratch directory deterministically;
    # load_document() has already read each file by the time the block exits.
    with tempfile.TemporaryDirectory() as temp_dir:
        for file in upload_files:
            # file.name is client-supplied: keep only its final path component
            # so the write cannot land outside the scratch directory.
            temp_filepath = os.path.join(temp_dir, os.path.basename(file.name))
            with open(temp_filepath, "wb") as f:
                f.write(file.getvalue())
            docs.extend(load_document(temp_filepath))

    retriever = config_retriever(
        docs=docs, use_compression=use_compression, chunk_size=use_chunksize
    )
    return config_baseretrieval_chain(retriever=retriever, temperature=use_temperature)
|
107 |
+
|
108 |
+
def config_noretrieval_chain(use_temperature=0.1,use_zeroshoot=False):
    """Build the chain (or search agent) for the document-free chat page.

    Args:
        use_temperature: Temperature for the chat model.
        use_zeroshoot: When true, return the DuckDuckGo search agent instead
            of a plain conversation chain.

    Returns:
        The agent from ``ddg_search_agent`` when ``use_zeroshoot`` is set,
        otherwise a ``ConversationChain`` around an Azure chat model.
    """
    # The agent builds its own chat model, so only construct one here when
    # the plain conversation chain is actually requested.
    if use_zeroshoot:
        return ddg_search_agent(temperature=use_temperature)

    chat_model = AzureChatOpenAI(
        api_key=st.secrets['key'],
        openai_api_version=st.secrets['chat_version'],
        azure_deployment=st.secrets['chat_name'],
        azure_endpoint=st.secrets['endpoint'],
        temperature=use_temperature,
    )
    return ConversationChain(llm=chat_model)
|
120 |
+
|
121 |
+
|
pages/Chatbot.py
ADDED
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
import logging
|
3 |
+
from utils import MEMORY, DocumentLoader, check_password
|
4 |
+
from chat import config_noretrieval_chain
|
5 |
+
from streamlit.external.langchain import StreamlitCallbackHandler
|
6 |
+
|
7 |
+
# Configure root logging at INFO with UTF-8 output and keep a handle on the
# root logger for this page.
logging.basicConfig(level=logging.INFO, encoding="utf-8")
LOGGER = logging.getLogger()
|
9 |
+
|
10 |
+
def main_chat_ui():
    """Render the plain chat page: controls, prompt box and the reply.

    Reads the temperature slider and the DuckDuckGo checkbox, builds the
    matching chain or agent with ``config_noretrieval_chain`` and, once the
    user submits a prompt, shows the prompt followed by the reply.
    """
    use_temperature = st.sidebar.slider('Temperature 🦄', 0.0, 1.0, (0.1))
    use_ddg_search = st.checkbox("Search on DuckDuckGO🦆", value=False)

    conv_chain = config_noretrieval_chain(
        use_temperature=use_temperature,
        use_zeroshoot=use_ddg_search,
    )

    if st.sidebar.button("Clear History🦭"):
        MEMORY.chat_memory.clear()
    if not MEMORY.chat_memory.messages:
        st.chat_message("assistant").markdown("Ask me something🤖")

    user_query = st.chat_input(placeholder="Say something🐻")
    if not user_query:
        # Nothing submitted on this run.
        return

    st.chat_message("user").write(user_query)
    container = st.empty()
    stream_handler = StreamlitCallbackHandler(container)
    with st.chat_message("assistant"):
        if use_ddg_search:
            # The agent result is indexed by key; the final answer is read
            # from "output".
            response = conv_chain.invoke(
                {"input": user_query}, {"callbacks": [stream_handler]}
            )
            st.write(response["output"])
        else:
            # Only this branch's response is rendered into the placeholder,
            # so the agent's result mapping is never dumped there.
            response = conv_chain.run(user_query)
            if response:
                container.markdown(response)
|
40 |
+
|
41 |
+
|
42 |
+
# Page entry point: stop the script run until the password check passes,
# then show the title and hand over to the chat UI.
password_ok = check_password()
if not password_ok:
    st.stop()

st.title("👻START CHAT👻")
main_chat_ui()
|
pages/Chatbot_with_uploaded_docs.py
ADDED
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
import logging
|
3 |
+
from utils import MEMORY, DocumentLoader, check_password
|
4 |
+
from chat import config_retrieval_chain
|
5 |
+
from streamlit.external.langchain import StreamlitCallbackHandler
|
6 |
+
|
7 |
+
# Configure root logging at INFO with UTF-8 output and keep a handle on the
# root logger for this page.
logging.basicConfig(level=logging.INFO, encoding="utf-8")
LOGGER = logging.getLogger()
|
9 |
+
|
10 |
+
def main_RAG_ui():
    """Render the chat-with-documents page: controls, prompt box and reply.

    Reads the chunk-size and temperature sliders and the compression and
    DuckDuckGo checkboxes, builds the matching chain or agent with
    ``config_retrieval_chain`` and, once the user submits a prompt, shows the
    prompt followed by the reply.

    Relies on the module-level ``uploaded_files`` having been assigned
    before this function is called.
    """
    use_chunk = st.sidebar.slider('Chunk Size', 500, 2000, (1000))
    use_temperature = st.sidebar.slider('Temperature 🦄', 0.0, 1.0, (0.1))
    use_compression = st.checkbox("Compression🛠️(on uploaded document)", value=False)
    use_ddg_search = st.checkbox("Search on DuckDuckGO🦆(does not use document)", value=False)

    conv_chain = config_retrieval_chain(
        uploaded_files,
        use_compression=use_compression,
        use_chunksize=use_chunk,
        use_temperature=use_temperature,
        use_zeroshoot=use_ddg_search,
    )

    if st.sidebar.button("Clear History🦭"):
        MEMORY.chat_memory.clear()
    if not MEMORY.chat_memory.messages:
        st.chat_message("assistant").markdown("Ask me something🤖")

    user_query = st.chat_input(placeholder="Say something🐻")
    if not user_query:
        # Nothing submitted on this run.
        return

    st.chat_message("user").write(user_query)
    container = st.empty()
    stream_handler = StreamlitCallbackHandler(container)
    with st.chat_message("assistant"):
        if use_ddg_search:
            # The agent result is indexed by key; the final answer is read
            # from "output".
            response = conv_chain.invoke(
                {"input": user_query}, {"callbacks": [stream_handler]}
            )
            st.write(response["output"])
        else:
            params = {
                "question": user_query,
                "chat_history": MEMORY.chat_memory.messages,
            }
            # Only this branch's response is rendered into the placeholder,
            # so the agent's result mapping is never dumped there.
            response = conv_chain.run(params, callbacks=[stream_handler])
            if response:
                container.markdown(response)
|
53 |
+
|
54 |
+
|
55 |
+
# Page entry point: stop the script run until the password check passes.
if not check_password():
    st.stop()

st.title("👻START CHAT👻")

# main_RAG_ui() reads this module-level name, so it must be assigned first.
# The accepted types are the keys of DocumentLoader.supported_extensions.
uploaded_files = st.sidebar.file_uploader(
    label="Upload a file🐣",
    type=list(DocumentLoader.supported_extensions),
    accept_multiple_files=True,
)

# Without at least one upload, show a hint and stop this script run.
if not uploaded_files:
    st.info("Upload a file to start🐣")
    st.stop()

main_RAG_ui()
|
requirements.txt
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
docarray==0.40.0
|
2 |
+
duckduckgo_search==6.2.1
|
3 |
+
langchain==0.2.11
|
4 |
+
langchain-community==0.2.10
|
5 |
+
langchain-core==0.2.23
|
6 |
+
langchain-openai==0.1.17
|
7 |
+
langchain-text-splitters==0.2.2
|
8 |
+
langsmith==0.1.93
|
9 |
+
pypdf==4.3.1
|
10 |
+
streamlit==1.36.0
|
utils.py
ADDED
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
import pathlib
|
3 |
+
from langchain_community.document_loaders import PyPDFLoader
|
4 |
+
from langchain_community.document_loaders import TextLoader
|
5 |
+
from langchain.memory import ConversationBufferMemory
|
6 |
+
from langchain.schema import Document
|
7 |
+
import hmac
|
8 |
+
import streamlit as st
|
9 |
+
|
10 |
+
def init_memory(key):
    """Create the conversation buffer used as chat memory.

    Args:
        key: Value passed to the buffer as ``memory_key``.

    Returns:
        A ``ConversationBufferMemory`` configured with
        ``return_messages=True`` and ``output_key='answer'``.
    """
    memory_options = {
        "memory_key": key,
        "return_messages": True,
        "output_key": 'answer',
    }
    return ConversationBufferMemory(**memory_options)


# Built once when this module is imported; other modules import this same
# object rather than creating their own.
MEMORY = init_memory('chat_history')
|
22 |
+
|
23 |
+
class DocumentLoaderException(Exception):
    """Raised when a document cannot be loaded, e.g. for an unsupported file extension."""
|
25 |
+
|
26 |
+
class DocumentLoader:
    """Registry mapping supported file extensions to their loader classes."""

    # Keys are dot-prefixed, lower-case extensions; values are the loader
    # classes that are instantiated with a file path.
    supported_extensions = {".pdf": PyPDFLoader, ".txt": TextLoader}
|
31 |
+
|
32 |
+
def load_document(temp_filepath: str) -> list[Document]:
    """Load the file at ``temp_filepath`` with the loader for its extension.

    Args:
        temp_filepath: Path of the file to load; its suffix selects the
            loader from ``DocumentLoader.supported_extensions``.

    Returns:
        The documents returned by the selected loader's ``load()``.

    Raises:
        DocumentLoaderException: If no loader is registered for the suffix.
    """
    # The registry keys are lower-case, so normalise the suffix to accept
    # names such as "REPORT.PDF" as well.
    ext = pathlib.Path(temp_filepath).suffix.lower()
    loader = DocumentLoader.supported_extensions.get(ext)
    if not loader:
        raise DocumentLoaderException(
            f"Invalid file extension: <{ext}>"
        )

    docs = loader(temp_filepath).load()
    # Log a summary rather than the documents themselves, which would write
    # the full uploaded content into the log.
    logging.info("Loaded %d document(s) from %s", len(docs), temp_filepath)
    return docs
|
44 |
+
|
45 |
+
|
46 |
+
def check_password():
    """Gate the page behind the password stored in Streamlit secrets.

    Returns:
        ``True`` once ``st.session_state["password_correct"]`` is truthy.
        Otherwise the password prompt is rendered (plus an error message
        after a failed attempt) and ``False`` is returned.
    """
    st.header("")

    def password_entered():
        """Compare the submitted password with ``st.secrets["adminpassword"]``."""
        # hmac.compare_digest() raises TypeError for str arguments that
        # contain non-ASCII characters, so compare UTF-8 bytes instead.
        entered = str(st.session_state["password"]).encode("utf-8")
        expected = str(st.secrets["adminpassword"]).encode("utf-8")
        if hmac.compare_digest(entered, expected):
            st.session_state["password_correct"] = True
            del st.session_state["password"]  # Don't store the password.
        else:
            st.session_state["password_correct"] = False

    if st.session_state.get("password_correct", False):
        return True

    st.text_input(
        "Enter Password 🚀", type="password", on_change=password_entered, key="password"
    )
    # The flag is only set by the callback above, and it is falsy at this
    # point, so its presence means an attempt was made and failed.
    if "password_correct" in st.session_state:
        st.error("Password incorrect 😕")
    return False
|