Spaces:
Paused
Paused
Commit
·
c102038
1
Parent(s):
c3f3537
feat: added the ability to log the history to DynamoDB
Browse files- EurLexChat.py +56 -2
- app.py +4 -38
- chat_utils.py +5 -2
- config.py +35 -0
- config.yaml +4 -6
- requirements.txt +2 -1
EurLexChat.py
CHANGED
@@ -1,3 +1,4 @@
|
|
|
|
1 |
from langchain_community.vectorstores import Qdrant
|
2 |
from langchain_core.runnables.history import RunnableWithMessageHistory
|
3 |
from langchain_core.runnables.base import RunnableLambda
|
@@ -6,7 +7,7 @@ from langchain_core.tools import StructuredTool
|
|
6 |
from langchain_core.utils.function_calling import convert_to_openai_tool
|
7 |
from langchain_core.messages import AIMessage
|
8 |
from typing import List
|
9 |
-
from chat_utils import get_init_modules, SYSTEM_PROMPT, SYSTEM_PROMPT_LOOP, ContextInput, Answer
|
10 |
from langchain_core.documents.base import Document
|
11 |
|
12 |
|
@@ -59,7 +60,7 @@ class EurLexChat:
|
|
59 |
input_messages_key="question",
|
60 |
history_messages_key="history",
|
61 |
)
|
62 |
-
|
63 |
self.relevant_documents_pipeline = ( self.retriever | self._parse_documents )
|
64 |
|
65 |
|
@@ -96,6 +97,14 @@ class EurLexChat:
|
|
96 |
if self.config["chatDB"]["class"] == 'FileChatMessageHistory':
|
97 |
file_path = f"{kwargs['output_path']}/{session_id}.json"
|
98 |
return self.chatDB_class(file_path=file_path)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
99 |
else:
|
100 |
return self.chatDB_class(session_id=session_id, **kwargs)
|
101 |
|
@@ -270,3 +279,48 @@ class EurLexChat:
|
|
270 |
return Answer(answer=result.answer)
|
271 |
return Answer(answer=result.content)
|
272 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import boto3
|
2 |
from langchain_community.vectorstores import Qdrant
|
3 |
from langchain_core.runnables.history import RunnableWithMessageHistory
|
4 |
from langchain_core.runnables.base import RunnableLambda
|
|
|
7 |
from langchain_core.utils.function_calling import convert_to_openai_tool
|
8 |
from langchain_core.messages import AIMessage
|
9 |
from typing import List
|
10 |
+
from chat_utils import get_init_modules, SYSTEM_PROMPT, SYSTEM_PROMPT_LOOP, ContextInput, Answer, get_vectorDB_module
|
11 |
from langchain_core.documents.base import Document
|
12 |
|
13 |
|
|
|
60 |
input_messages_key="question",
|
61 |
history_messages_key="history",
|
62 |
)
|
63 |
+
|
64 |
self.relevant_documents_pipeline = ( self.retriever | self._parse_documents )
|
65 |
|
66 |
|
|
|
97 |
if self.config["chatDB"]["class"] == 'FileChatMessageHistory':
|
98 |
file_path = f"{kwargs['output_path']}/{session_id}.json"
|
99 |
return self.chatDB_class(file_path=file_path)
|
100 |
+
elif self.config["chatDB"]["class"] == 'DynamoDBChatMessageHistory':
|
101 |
+
table_name = kwargs["table_name"]
|
102 |
+
session = boto3.Session(aws_access_key_id=kwargs["aws_access_key_id"],
|
103 |
+
aws_secret_access_key=kwargs["aws_secret_access_key"],
|
104 |
+
region_name='eu-west-1')
|
105 |
+
return self.chatDB_class(session_id=session_id,
|
106 |
+
table_name=table_name,
|
107 |
+
boto3_session=session)
|
108 |
else:
|
109 |
return self.chatDB_class(session_id=session_id, **kwargs)
|
110 |
|
|
|
279 |
return Answer(answer=result.answer)
|
280 |
return Answer(answer=result.content)
|
281 |
|
282 |
+
|
283 |
+
class EurLexChatAkn(EurLexChat):
    """
    EurLexChat subclass that defines its own document parsing and a
    retrieval method with optional Eurovoc filtering.
    """

    def _parse_documents(self, docs: List[Document]) -> List[dict]:
        """
        Parse a list of documents into a standardized format.

        Args:
            docs (List[Document]): A list of documents to parse. The metadata
                of every document must contain the keys "uri" and "article_id".

        Returns:
            List[dict]: One dictionary per input document with the keys
                'text' (the page content), 'source' (the uri) and '_id'
                (the uri concatenated with the article id).

        Raises:
            KeyError: If a document's metadata lacks "uri" or "article_id".
        """
        return [
            {
                'text': doc.page_content,
                'source': doc.metadata["uri"],
                '_id': doc.metadata["uri"] + doc.metadata["article_id"],
            }
            for doc in docs
        ]

    def get_relevant_docs(self, question: str, eurovoc: str = None) -> List[dict]:
        """
        Retrieve relevant documents based on a given question.

        Args:
            question (str): The question for which relevant documents are retrieved.
            eurovoc (str): The Eurovoc to be used as filter. When empty or None,
                the default (unfiltered) retrieval pipeline is used.

        Returns:
            List[dict]: A list of relevant documents, in the format produced
                by `_parse_documents`.
        """
        if not eurovoc:
            return self.relevant_documents_pipeline.invoke(question)

        # A dedicated retriever is built per call so the filter applies only
        # to this request. The caller's eurovoc value is forwarded; the
        # previous code always sent an empty string, ignoring the argument.
        retriever = get_vectorDB_module(
            self.config['vectorDB'],
            self.embedder,
            metadata={'filter': {'eurovoc': eurovoc}},
        )
        filtered_pipeline = retriever | self._parse_documents
        return filtered_pipeline.invoke(question)
|
app.py
CHANGED
@@ -1,18 +1,8 @@
|
|
1 |
import gradio as gr
|
2 |
from EurLexChat import EurLexChat
|
3 |
-
import yaml
|
4 |
import random
|
5 |
import string
|
6 |
-
import
|
7 |
-
import os
|
8 |
-
|
9 |
-
openai_org_key = os.getenv("OPENAI_ORG_KEY")
|
10 |
-
openai_key = os.getenv("OPENAI_KEY")
|
11 |
-
ui_pwd = os.getenv("pwd")
|
12 |
-
ui_user = os.getenv("user")
|
13 |
-
qdrant_url=os.getenv("url")
|
14 |
-
qdrant_key=os.getenv("qdrant_key")
|
15 |
-
|
16 |
|
17 |
def generate_random_string(length):
|
18 |
# Generate a random string of the specified length
|
@@ -25,32 +15,8 @@ class Documents():
|
|
25 |
def __init__(self) -> None:
|
26 |
self.documents = []
|
27 |
|
28 |
-
parser = argparse.ArgumentParser(description="Chat-eur-lex ui")
|
29 |
-
|
30 |
-
parser.add_argument('--config_path',
|
31 |
-
dest='config_path',
|
32 |
-
metavar='config_path',
|
33 |
-
type=str,
|
34 |
-
help='The path to the config file that contains all the settings for the chat engine' ,
|
35 |
-
default='config.yaml')
|
36 |
-
args = parser.parse_args()
|
37 |
-
|
38 |
-
# Read config file
|
39 |
-
with open(args.config_path, 'r') as file:
|
40 |
-
config = yaml.safe_load(file)
|
41 |
-
|
42 |
-
config["embeddings"]["kwargs"]["openai_api_key"] = openai_key
|
43 |
-
config["embeddings"]["kwargs"]["openai_organization"] = openai_org_key
|
44 |
-
config["llm"]["kwargs"]["openai_api_key"] = openai_key
|
45 |
-
config["llm"]["kwargs"]["openai_organization"] = openai_org_key
|
46 |
-
config["vectorDB"]["kwargs"]["url"] = qdrant_url
|
47 |
-
config["vectorDB"]["kwargs"]["api_key"] = qdrant_key
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
|
53 |
-
chat = EurLexChat(config=
|
54 |
docs = Documents()
|
55 |
|
56 |
|
@@ -113,7 +79,7 @@ with block:
|
|
113 |
|
114 |
with gr.Column(scale=1, visible=False) as col:
|
115 |
gr.Markdown("""<h3><center>Context documents</center></h3>""")
|
116 |
-
for i in range(
|
117 |
with gr.Accordion(label="", elem_id=f'accordion_{i}', open=False) as acc:
|
118 |
list_texts.append(gr.Textbox("", elem_id=f'text_{i}', show_label=False, lines=10))
|
119 |
btn = gr.Button(f"Remove document")
|
@@ -141,4 +107,4 @@ with block:
|
|
141 |
for i, b in enumerate(delete_buttons):
|
142 |
b.click(remove_doc, inputs=states[i], outputs=[*accordions, *list_texts])
|
143 |
|
144 |
-
block.launch(debug=True, auth=(
|
|
|
1 |
import gradio as gr
|
2 |
from EurLexChat import EurLexChat
|
|
|
3 |
import random
|
4 |
import string
|
5 |
+
from config import CONFIG, UI_USER, UI_PWD
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
6 |
|
7 |
def generate_random_string(length):
|
8 |
# Generate a random string of the specified length
|
|
|
15 |
def __init__(self) -> None:
|
16 |
self.documents = []
|
17 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
18 |
|
19 |
+
chat = EurLexChat(config=CONFIG)
|
20 |
docs = Documents()
|
21 |
|
22 |
|
|
|
79 |
|
80 |
with gr.Column(scale=1, visible=False) as col:
|
81 |
gr.Markdown("""<h3><center>Context documents</center></h3>""")
|
82 |
+
for i in range(CONFIG['vectorDB']['retriever_args']['search_kwargs']['k']):
|
83 |
with gr.Accordion(label="", elem_id=f'accordion_{i}', open=False) as acc:
|
84 |
list_texts.append(gr.Textbox("", elem_id=f'text_{i}', show_label=False, lines=10))
|
85 |
btn = gr.Button(f"Remove document")
|
|
|
107 |
for i, b in enumerate(delete_buttons):
|
108 |
b.click(remove_doc, inputs=states[i], outputs=[*accordions, *list_texts])
|
109 |
|
110 |
+
block.launch(debug=True, auth=(UI_USER, UI_PWD))
|
chat_utils.py
CHANGED
@@ -64,7 +64,7 @@ def get_init_modules(config):
|
|
64 |
return embedder, llm, chatDB_class, retriever
|
65 |
|
66 |
|
67 |
-
def get_vectorDB_module(db_config, embedder):
|
68 |
mod_chat = __import__("langchain_community.vectorstores",
|
69 |
fromlist=[db_config["class"]])
|
70 |
vectorDB_class = getattr(mod_chat, db_config["class"])
|
@@ -85,10 +85,13 @@ def get_vectorDB_module(db_config, embedder):
|
|
85 |
|
86 |
client = QdrantClient(**client_kwargs)
|
87 |
|
|
|
|
|
88 |
retriever = vectorDB_class(
|
89 |
client, embeddings=embedder, **db_kwargs).as_retriever(
|
90 |
search_type=db_config["retriever_args"]["search_type"],
|
91 |
-
search_kwargs=db_config["retriever_args"]["search_kwargs"]
|
|
|
92 |
)
|
93 |
|
94 |
else:
|
|
|
64 |
return embedder, llm, chatDB_class, retriever
|
65 |
|
66 |
|
67 |
+
def get_vectorDB_module(db_config, embedder, metadata=None):
|
68 |
mod_chat = __import__("langchain_community.vectorstores",
|
69 |
fromlist=[db_config["class"]])
|
70 |
vectorDB_class = getattr(mod_chat, db_config["class"])
|
|
|
85 |
|
86 |
client = QdrantClient(**client_kwargs)
|
87 |
|
88 |
+
if metadata is None:
|
89 |
+
metadata = {}
|
90 |
retriever = vectorDB_class(
|
91 |
client, embeddings=embedder, **db_kwargs).as_retriever(
|
92 |
search_type=db_config["retriever_args"]["search_type"],
|
93 |
+
search_kwargs={**db_config["retriever_args"]["search_kwargs"], **metadata},
|
94 |
+
filter=metadata
|
95 |
)
|
96 |
|
97 |
else:
|
config.py
ADDED
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import yaml
|
3 |
+
|
4 |
+
|
5 |
+
# Read config file.
# The file is opened directly instead of being checked with os.path.exists
# first; a missing file still surfaces as FileNotFoundError, now with a
# message that names the file actually being opened.
try:
    with open('config.yaml', 'r', encoding='utf-8') as config_file:
        CONFIG = yaml.safe_load(config_file)
except FileNotFoundError as err:
    raise FileNotFoundError('config.yaml not found. Aborting!') from err

# OpenAI credentials come from the environment only.
OPENAI_ORG_KEY = os.getenv("OPENAI_ORG_KEY", "")
OPENAI_KEY = os.getenv("OPENAI_KEY", "")

# Qdrant settings: the environment variable wins, then the value in
# config.yaml, then an empty string.
QDRANT_URL = os.getenv("url", CONFIG["vectorDB"]["kwargs"].get("url", ""))
QDRANT_KEY = os.getenv("qdrant_key", CONFIG["vectorDB"]["kwargs"].get("api_key", ""))

# Credentials for the UI login.
# NOTE(review): these fall back to "admin"/"admin" when the "user"/"pwd"
# environment variables are unset - make sure they are set wherever the UI
# is reachable by others.
UI_USER = os.getenv("user", "admin")
UI_PWD = os.getenv("pwd", "admin")

# Write the resolved values back into CONFIG so that code receiving CONFIG
# sees the final settings.
CONFIG["embeddings"]["kwargs"]["openai_api_key"] = OPENAI_KEY
CONFIG["embeddings"]["kwargs"]["openai_organization"] = OPENAI_ORG_KEY
CONFIG["llm"]["kwargs"]["openai_api_key"] = OPENAI_KEY
CONFIG["llm"]["kwargs"]["openai_organization"] = OPENAI_ORG_KEY
CONFIG["vectorDB"]["kwargs"]["url"] = QDRANT_URL
CONFIG["vectorDB"]["kwargs"]["api_key"] = QDRANT_KEY

# If the history should be stored on AWS DynamoDB, resolve the table name and
# the AWS credentials (environment first, then config.yaml); otherwise the
# history is stored on the local FS at the output_path defined in config.yaml.
if CONFIG['chatDB']['class'] == 'DynamoDBChatMessageHistory':
    _chatdb_kwargs = CONFIG["chatDB"]["kwargs"]
    CHATDB_TABLE_NAME = os.getenv(
        "CHATDB_TABLE_NAME", _chatdb_kwargs.get("table_name", "ChatEurlexHistory"))
    AWS_ACCESS_KEY_ID = os.getenv(
        "AWS_ACCESS_KEY_ID", _chatdb_kwargs.get("aws_access_key_id", ""))
    AWS_SECRET_ACCESS_KEY = os.getenv(
        "AWS_SECRET_ACCESS_KEY", _chatdb_kwargs.get("aws_secret_access_key", ""))
    _chatdb_kwargs["table_name"] = CHATDB_TABLE_NAME
    _chatdb_kwargs["aws_access_key_id"] = AWS_ACCESS_KEY_ID
    _chatdb_kwargs["aws_secret_access_key"] = AWS_SECRET_ACCESS_KEY
|
config.yaml
CHANGED
@@ -16,8 +16,6 @@ vectorDB:
|
|
16 |
embeddings:
|
17 |
class: OpenAIEmbeddings
|
18 |
kwargs:
|
19 |
-
openai_api_key: ""
|
20 |
-
openai_organization: ""
|
21 |
model: text-embedding-ada-002
|
22 |
|
23 |
|
@@ -26,15 +24,15 @@ llm:
|
|
26 |
use_context_function: True
|
27 |
max_context_size: 6000
|
28 |
kwargs:
|
29 |
-
openai_organization: ""
|
30 |
-
openai_api_key: ""
|
31 |
model_name: gpt-4
|
32 |
temperature: 0.8
|
33 |
|
34 |
|
35 |
chatDB:
|
36 |
-
class:
|
37 |
kwargs:
|
38 |
-
|
|
|
|
|
39 |
|
40 |
max_history_messages: 5
|
|
|
16 |
embeddings:
|
17 |
class: OpenAIEmbeddings
|
18 |
kwargs:
|
|
|
|
|
19 |
model: text-embedding-ada-002
|
20 |
|
21 |
|
|
|
24 |
use_context_function: True
|
25 |
max_context_size: 6000
|
26 |
kwargs:
|
|
|
|
|
27 |
model_name: gpt-4
|
28 |
temperature: 0.8
|
29 |
|
30 |
|
31 |
chatDB:
|
32 |
+
class: DynamoDBChatMessageHistory
|
33 |
kwargs:
|
34 |
+
table_name: ''
|
35 |
+
aws_access_key_id: ''
|
36 |
+
aws_secret_access_key: ''
|
37 |
|
38 |
max_history_messages: 5
|
requirements.txt
CHANGED
@@ -4,4 +4,5 @@ tiktoken==0.6.0
|
|
4 |
qdrant-client==1.7.3
|
5 |
transformers==4.37.2
|
6 |
openai==1.12.0
|
7 |
-
gradio==4.18.0
|
|
|
|
4 |
qdrant-client==1.7.3
|
5 |
transformers==4.37.2
|
6 |
openai==1.12.0
|
7 |
+
gradio==4.18.0
|
8 |
+
boto3==1.34
|