sinafarhangdoust commited on
Commit
c102038
·
1 Parent(s): c3f3537

feat: added the ability to log the history to DynamoDB

Browse files
Files changed (6) hide show
  1. EurLexChat.py +56 -2
  2. app.py +4 -38
  3. chat_utils.py +5 -2
  4. config.py +35 -0
  5. config.yaml +4 -6
  6. requirements.txt +2 -1
EurLexChat.py CHANGED
@@ -1,3 +1,4 @@
 
1
  from langchain_community.vectorstores import Qdrant
2
  from langchain_core.runnables.history import RunnableWithMessageHistory
3
  from langchain_core.runnables.base import RunnableLambda
@@ -6,7 +7,7 @@ from langchain_core.tools import StructuredTool
6
  from langchain_core.utils.function_calling import convert_to_openai_tool
7
  from langchain_core.messages import AIMessage
8
  from typing import List
9
- from chat_utils import get_init_modules, SYSTEM_PROMPT, SYSTEM_PROMPT_LOOP, ContextInput, Answer
10
  from langchain_core.documents.base import Document
11
 
12
 
@@ -59,7 +60,7 @@ class EurLexChat:
59
  input_messages_key="question",
60
  history_messages_key="history",
61
  )
62
-
63
  self.relevant_documents_pipeline = ( self.retriever | self._parse_documents )
64
 
65
 
@@ -96,6 +97,14 @@ class EurLexChat:
96
  if self.config["chatDB"]["class"] == 'FileChatMessageHistory':
97
  file_path = f"{kwargs['output_path']}/{session_id}.json"
98
  return self.chatDB_class(file_path=file_path)
 
 
 
 
 
 
 
 
99
  else:
100
  return self.chatDB_class(session_id=session_id, **kwargs)
101
 
@@ -270,3 +279,48 @@ class EurLexChat:
270
  return Answer(answer=result.answer)
271
  return Answer(answer=result.content)
272
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import boto3
2
  from langchain_community.vectorstores import Qdrant
3
  from langchain_core.runnables.history import RunnableWithMessageHistory
4
  from langchain_core.runnables.base import RunnableLambda
 
7
  from langchain_core.utils.function_calling import convert_to_openai_tool
8
  from langchain_core.messages import AIMessage
9
  from typing import List
10
+ from chat_utils import get_init_modules, SYSTEM_PROMPT, SYSTEM_PROMPT_LOOP, ContextInput, Answer, get_vectorDB_module
11
  from langchain_core.documents.base import Document
12
 
13
 
 
60
  input_messages_key="question",
61
  history_messages_key="history",
62
  )
63
+
64
  self.relevant_documents_pipeline = ( self.retriever | self._parse_documents )
65
 
66
 
 
97
  if self.config["chatDB"]["class"] == 'FileChatMessageHistory':
98
  file_path = f"{kwargs['output_path']}/{session_id}.json"
99
  return self.chatDB_class(file_path=file_path)
100
+ elif self.config["chatDB"]["class"] == 'DynamoDBChatMessageHistory':
101
+ table_name = kwargs["table_name"]
102
+ session = boto3.Session(aws_access_key_id=kwargs["aws_access_key_id"],
103
+ aws_secret_access_key=kwargs["aws_secret_access_key"],
104
+ region_name='eu-west-1')
105
+ return self.chatDB_class(session_id=session_id,
106
+ table_name=table_name,
107
+ boto3_session=session)
108
  else:
109
  return self.chatDB_class(session_id=session_id, **kwargs)
110
 
 
279
  return Answer(answer=result.answer)
280
  return Answer(answer=result.content)
281
 
282
+
283
class EurLexChatAkn(EurLexChat):
    """EurLexChat variant for Akoma Ntoso (AKN) style documents.

    AKN documents carry a ``uri`` and an ``article_id`` in their metadata
    instead of a plain source field, and searches may optionally be
    restricted to a single Eurovoc concept.
    """

    def _parse_documents(self, docs: List[Document]) -> List[dict]:
        """
        Parse a list of documents into a standardized format.

        Args:
            docs (List[Document]): A list of documents to parse.

        Returns:
            List[dict]: A list of dictionaries, each containing parsed
                information from the input documents.
        """
        return [
            {
                'text': doc.page_content,
                'source': doc.metadata["uri"],
                # A uri alone is not unique per article, so append the
                # article_id to build a stable identifier.
                '_id': doc.metadata["uri"] + doc.metadata["article_id"],
            }
            for doc in docs
        ]

    def get_relevant_docs(self, question: str, eurovoc: str = None) -> List[dict]:
        """
        Retrieve relevant documents based on a given question.

        Args:
            question (str): The question for which relevant documents are retrieved.
            eurovoc (str): Optional Eurovoc concept used to filter the search.

        Returns:
            List[dict]: A list of relevant documents.
        """
        if eurovoc:
            # Build a one-off retriever restricted to the requested Eurovoc
            # concept. Fix: the filter previously hard-coded an empty string
            # ({'eurovoc': ''}), silently ignoring the eurovoc argument.
            retriever = get_vectorDB_module(
                self.config['vectorDB'], self.embedder,
                metadata={'filter': {'eurovoc': eurovoc}}
            )
            filtered_pipeline = (retriever | self._parse_documents)
            return filtered_pipeline.invoke(question)
        return self.relevant_documents_pipeline.invoke(question)
app.py CHANGED
@@ -1,18 +1,8 @@
1
  import gradio as gr
2
  from EurLexChat import EurLexChat
3
- import yaml
4
  import random
5
  import string
6
- import argparse
7
- import os
8
-
9
- openai_org_key = os.getenv("OPENAI_ORG_KEY")
10
- openai_key = os.getenv("OPENAI_KEY")
11
- ui_pwd = os.getenv("pwd")
12
- ui_user = os.getenv("user")
13
- qdrant_url=os.getenv("url")
14
- qdrant_key=os.getenv("qdrant_key")
15
-
16
 
17
  def generate_random_string(length):
18
  # Generate a random string of the specified length
@@ -25,32 +15,8 @@ class Documents():
25
  def __init__(self) -> None:
26
  self.documents = []
27
 
28
- parser = argparse.ArgumentParser(description="Chat-eur-lex ui")
29
-
30
- parser.add_argument('--config_path',
31
- dest='config_path',
32
- metavar='config_path',
33
- type=str,
34
- help='The path to the config file that contains all the settings for the chat engine' ,
35
- default='config.yaml')
36
- args = parser.parse_args()
37
-
38
- # Read config file
39
- with open(args.config_path, 'r') as file:
40
- config = yaml.safe_load(file)
41
-
42
- config["embeddings"]["kwargs"]["openai_api_key"] = openai_key
43
- config["embeddings"]["kwargs"]["openai_organization"] = openai_org_key
44
- config["llm"]["kwargs"]["openai_api_key"] = openai_key
45
- config["llm"]["kwargs"]["openai_organization"] = openai_org_key
46
- config["vectorDB"]["kwargs"]["url"] = qdrant_url
47
- config["vectorDB"]["kwargs"]["api_key"] = qdrant_key
48
-
49
-
50
-
51
-
52
 
53
- chat = EurLexChat(config=config)
54
  docs = Documents()
55
 
56
 
@@ -113,7 +79,7 @@ with block:
113
 
114
  with gr.Column(scale=1, visible=False) as col:
115
  gr.Markdown("""<h3><center>Context documents</center></h3>""")
116
- for i in range(config['vectorDB']['retriever_args']['search_kwargs']['k']):
117
  with gr.Accordion(label="", elem_id=f'accordion_{i}', open=False) as acc:
118
  list_texts.append(gr.Textbox("", elem_id=f'text_{i}', show_label=False, lines=10))
119
  btn = gr.Button(f"Remove document")
@@ -141,4 +107,4 @@ with block:
141
  for i, b in enumerate(delete_buttons):
142
  b.click(remove_doc, inputs=states[i], outputs=[*accordions, *list_texts])
143
 
144
- block.launch(debug=True, auth=(ui_user, ui_pwd))
 
1
  import gradio as gr
2
  from EurLexChat import EurLexChat
 
3
  import random
4
  import string
5
+ from config import CONFIG, UI_USER, UI_PWD
 
 
 
 
 
 
 
 
 
6
 
7
  def generate_random_string(length):
8
  # Generate a random string of the specified length
 
15
  def __init__(self) -> None:
16
  self.documents = []
17
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
 
19
+ chat = EurLexChat(config=CONFIG)
20
  docs = Documents()
21
 
22
 
 
79
 
80
  with gr.Column(scale=1, visible=False) as col:
81
  gr.Markdown("""<h3><center>Context documents</center></h3>""")
82
+ for i in range(CONFIG['vectorDB']['retriever_args']['search_kwargs']['k']):
83
  with gr.Accordion(label="", elem_id=f'accordion_{i}', open=False) as acc:
84
  list_texts.append(gr.Textbox("", elem_id=f'text_{i}', show_label=False, lines=10))
85
  btn = gr.Button(f"Remove document")
 
107
  for i, b in enumerate(delete_buttons):
108
  b.click(remove_doc, inputs=states[i], outputs=[*accordions, *list_texts])
109
 
110
+ block.launch(debug=True, auth=(UI_USER, UI_PWD))
chat_utils.py CHANGED
@@ -64,7 +64,7 @@ def get_init_modules(config):
64
  return embedder, llm, chatDB_class, retriever
65
 
66
 
67
- def get_vectorDB_module(db_config, embedder):
68
  mod_chat = __import__("langchain_community.vectorstores",
69
  fromlist=[db_config["class"]])
70
  vectorDB_class = getattr(mod_chat, db_config["class"])
@@ -85,10 +85,13 @@ def get_vectorDB_module(db_config, embedder):
85
 
86
  client = QdrantClient(**client_kwargs)
87
 
 
 
88
  retriever = vectorDB_class(
89
  client, embeddings=embedder, **db_kwargs).as_retriever(
90
  search_type=db_config["retriever_args"]["search_type"],
91
- search_kwargs=db_config["retriever_args"]["search_kwargs"]
 
92
  )
93
 
94
  else:
 
64
  return embedder, llm, chatDB_class, retriever
65
 
66
 
67
+ def get_vectorDB_module(db_config, embedder, metadata=None):
68
  mod_chat = __import__("langchain_community.vectorstores",
69
  fromlist=[db_config["class"]])
70
  vectorDB_class = getattr(mod_chat, db_config["class"])
 
85
 
86
  client = QdrantClient(**client_kwargs)
87
 
88
+ if metadata is None:
89
+ metadata = {}
90
  retriever = vectorDB_class(
91
  client, embeddings=embedder, **db_kwargs).as_retriever(
92
  search_type=db_config["retriever_args"]["search_type"],
93
+ search_kwargs={**db_config["retriever_args"]["search_kwargs"], **metadata},
94
+ filter=metadata
95
  )
96
 
97
  else:
config.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
import os
import yaml


# Load the YAML configuration once at import time; every other module reads
# (and relies on) the mutated CONFIG dict exported from here instead of
# parsing the file itself.
if os.path.exists('config.yaml'):
    with open('config.yaml', 'r') as file:
        CONFIG = yaml.safe_load(file)
else:
    # Fix: the message previously said 'config.yml' while the code checks
    # for 'config.yaml', and lacked punctuation.
    raise FileNotFoundError('config.yaml not found. Aborting!')

# Secrets come from the environment first, falling back to any values already
# present in config.yaml (empty string when neither is set).
OPENAI_ORG_KEY = os.getenv("OPENAI_ORG_KEY", "")
OPENAI_KEY = os.getenv("OPENAI_KEY", "")
QDRANT_URL = os.getenv("url", CONFIG["vectorDB"]["kwargs"].get("url", ""))
QDRANT_KEY = os.getenv("qdrant_key", CONFIG["vectorDB"]["kwargs"].get("api_key", ""))

# Basic-auth credentials for the Gradio UI.
UI_USER = os.getenv("user", "admin")
UI_PWD = os.getenv("pwd", "admin")

CONFIG["embeddings"]["kwargs"]["openai_api_key"] = OPENAI_KEY
CONFIG["embeddings"]["kwargs"]["openai_organization"] = OPENAI_ORG_KEY
CONFIG["llm"]["kwargs"]["openai_api_key"] = OPENAI_KEY
CONFIG["llm"]["kwargs"]["openai_organization"] = OPENAI_ORG_KEY
CONFIG["vectorDB"]["kwargs"]["url"] = QDRANT_URL
CONFIG["vectorDB"]["kwargs"]["api_key"] = QDRANT_KEY

# If the chat history should be stored on AWS DynamoDB; otherwise it is
# stored on the local FS at the output_path defined in config.yaml.
if CONFIG['chatDB']['class'] == 'DynamoDBChatMessageHistory':
    CHATDB_TABLE_NAME = os.getenv("CHATDB_TABLE_NAME", CONFIG["chatDB"]["kwargs"].get("table_name", "ChatEurlexHistory"))
    AWS_ACCESS_KEY_ID = os.getenv("AWS_ACCESS_KEY_ID", CONFIG["chatDB"]["kwargs"].get("aws_access_key_id", ""))
    AWS_SECRET_ACCESS_KEY = os.getenv("AWS_SECRET_ACCESS_KEY", CONFIG["chatDB"]["kwargs"].get("aws_secret_access_key", ""))
    CONFIG["chatDB"]["kwargs"]["table_name"] = CHATDB_TABLE_NAME
    CONFIG["chatDB"]["kwargs"]["aws_access_key_id"] = AWS_ACCESS_KEY_ID
    CONFIG["chatDB"]["kwargs"]["aws_secret_access_key"] = AWS_SECRET_ACCESS_KEY
config.yaml CHANGED
@@ -16,8 +16,6 @@ vectorDB:
16
  embeddings:
17
  class: OpenAIEmbeddings
18
  kwargs:
19
- openai_api_key: ""
20
- openai_organization: ""
21
  model: text-embedding-ada-002
22
 
23
 
@@ -26,15 +24,15 @@ llm:
26
  use_context_function: True
27
  max_context_size: 6000
28
  kwargs:
29
- openai_organization: ""
30
- openai_api_key: ""
31
  model_name: gpt-4
32
  temperature: 0.8
33
 
34
 
35
  chatDB:
36
- class: FileChatMessageHistory
37
  kwargs:
38
- output_path: ./output
 
 
39
 
40
  max_history_messages: 5
 
16
  embeddings:
17
  class: OpenAIEmbeddings
18
  kwargs:
 
 
19
  model: text-embedding-ada-002
20
 
21
 
 
24
  use_context_function: True
25
  max_context_size: 6000
26
  kwargs:
 
 
27
  model_name: gpt-4
28
  temperature: 0.8
29
 
30
 
31
  chatDB:
32
+ class: DynamoDBChatMessageHistory
33
  kwargs:
34
+ table_name: ''
35
+ aws_access_key_id: ''
36
+ aws_secret_access_key: ''
37
 
38
  max_history_messages: 5
requirements.txt CHANGED
@@ -4,4 +4,5 @@ tiktoken==0.6.0
4
  qdrant-client==1.7.3
5
  transformers==4.37.2
6
  openai==1.12.0
7
- gradio==4.18.0
 
 
4
  qdrant-client==1.7.3
5
  transformers==4.37.2
6
  openai==1.12.0
7
+ gradio==4.18.0
8
+ boto3==1.34