sinafarhangdoust committed
Commit e34a2a6 · 1 Parent(s): c102038

feat: added the AKN + limited search space version for the Chat-Eurlex

Files changed (7):
  1. EurLexChat.py +121 -79
  2. app.py +59 -21
  3. chat_utils.py +33 -9
  4. config.py +13 -3
  5. config.yaml +12 -5
  6. consts.py +73 -0
  7. requirements.txt +4 -3
EurLexChat.py CHANGED
@@ -6,21 +6,26 @@ from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
 from langchain_core.tools import StructuredTool
 from langchain_core.utils.function_calling import convert_to_openai_tool
 from langchain_core.messages import AIMessage
-from typing import List
+from typing import List, Optional
 from chat_utils import get_init_modules, SYSTEM_PROMPT, SYSTEM_PROMPT_LOOP, ContextInput, Answer, get_vectorDB_module
 from langchain_core.documents.base import Document
-
+from langchain_core.runnables import ConfigurableField
+import qdrant_client.models as rest
 
 class EurLexChat:
     def __init__(self, config: dict):
         self.config = config
         self.max_history_messages = self.config["max_history_messages"]
+        self.vectorDB_class = self.config['vectorDB']['class']
         self.use_functions = (
-            'use_context_function' in config["llm"] and
-            config["llm"]["use_context_function"] and
+            'use_context_function' in config["llm"] and
+            config["llm"]["use_context_function"] and
             config["llm"]["class"] == "ChatOpenAI")
 
-        self.embedder, self.llm, self.chatDB_class, self.retriever = get_init_modules(config)
+        self.embedder, self.llm, self.chatDB_class, self.retriever, retriever_chain = get_init_modules(
+            config)
+
+
         self.max_context_size = config["llm"]["max_context_size"]
 
         self.prompt = ChatPromptTemplate.from_messages([
@@ -43,17 +48,26 @@ class EurLexChat:
                 name="get_context",
                 description="To be used whenever the provided context is empty or the user changes the topic of the conversation and you need the context for the topic. " +
                 "This function must be called only when is strictly necessary. " +
-                "This function must not be called if you already have the information to answer the user. ",
+                "This function must not be called if you already have in the context the information to answer the user. ",
                 args_schema=ContextInput
             )
 
-            # self.llm = self.llm.bind(tools=[convert_to_openai_tool(GET_CONTEXT_TOOL)])
-            self.llm_with_functions = self.llm.bind(tools=[convert_to_openai_tool(GET_CONTEXT_TOOL)])
-
-            chain = self.prompt | RunnableLambda(self._resize_history) | self.llm_with_functions
+            self.llm_with_functions = self.llm.bind(
+                tools=[convert_to_openai_tool(GET_CONTEXT_TOOL)]
+            )
+
+            chain = (
+                self.prompt |
+                RunnableLambda(self._resize_history) |
+                self.llm_with_functions
+            )
         else:
-            chain = self.prompt | RunnableLambda(self._resize_history) | self.llm
-
+            chain = (
+                self.prompt |
+                RunnableLambda(self._resize_history) |
+                self.llm
+            )
+
         self.chain_with_history = RunnableWithMessageHistory(
             chain,
             self.get_chat_history,
@@ -61,8 +75,7 @@
             history_messages_key="history",
         )
 
-        self.relevant_documents_pipeline = ( self.retriever | self._parse_documents )
-
+        self.relevant_documents_pipeline = (retriever_chain | self._parse_documents)
 
     def _resize_history(self, input_dict):
         """
@@ -77,11 +90,10 @@
 
         messages = input_dict.messages
         if (len(messages) - 2) > self.max_history_messages:
-            messages = [messages[0]] + messages[-(self.max_history_messages +1):]
+            messages = [messages[0]] + messages[-(self.max_history_messages + 1):]
         input_dict.messages = messages
         return input_dict
 
-
     def get_chat_history(self, session_id: str):
         """
         Retrieve chat history instance for a specific session ID.
@@ -108,7 +120,6 @@
         else:
            return self.chatDB_class(session_id=session_id, **kwargs)
 
-
     def _parse_documents(self, docs: List[Document]) -> List[dict]:
         """
         Parse a list of documents into a standardized format.
@@ -126,11 +137,11 @@
             parsed_documents.append({
                 'text': doc.page_content,
                 'source': doc.metadata["source"],
+                'celex': doc.metadata["celex"],
                 '_id': doc.metadata["_id"]
             })
         return parsed_documents
 
-
     def _format_context_docs(self, context_docs: List[dict]) -> str:
         """
         Format a list of documents into a single string.
@@ -147,37 +158,107 @@
             context_str += doc['text'] + "\n\n"
         return context_str
 
-    def get_relevant_docs(self, question:str) -> List[dict]:
+    def get_ids_from_celexes(self, celex_list: List[str]):
+        """
+        Retrieve the IDs of the documents given their CELEX numbers.
+
+        Args:
+            celex_list (List[str]): A list of CELEX numbers.
+
+        Returns:
+            List[str]: A list of document IDs corresponding to the provided CELEX numbers
+        """
+
+        if self.vectorDB_class == 'Qdrant':
+            scroll_filter = rest.Filter(
+                must=[
+                    rest.FieldCondition(
+                        key="celex",
+                        match=rest.MatchAny(any=celex_list),
+                    )
+                ])
+            offset = -1
+            ids = []
+            while not (offset is None and offset != -1):
+                if offset == -1:
+                    offset = None
+                points, offset = self.retriever.vectorstore.client.scroll(
+                    collection_name=self.retriever.vectorstore.collection_name,
+                    limit=100,
+                    offset=offset,
+                    scroll_filter=scroll_filter,
+                    with_payload=False
+                )
+                ids.extend([p.id for p in points])
+        else:
+            raise NotImplementedError(f"Not supported {self.vectorDB_class} vectorDB class")
+        return ids
+
+    def _get_qdrant_ids_filter(self, ids):
+        """
+        Returns a Qdrant filter to filter documents based on their IDs.
+
+        This function acts as a workaround due to a hidden bug in Qdrant
+        that prevents correct filtering using CELEX numbers.
+
+        Args:
+            ids (List[str]): A list of document IDs.
+
+        Returns:
+            Qdrant filter: A Qdrant filter to filter documents based on their IDs.
+        """
+
+        filter = rest.Filter(
+            must=[
+                rest.HasIdCondition(has_id=ids),
+            ],
+        )
+
+        return filter
+
+    def get_relevant_docs(self, question: str, ids_list: Optional[List[str]] = None) -> List[dict]:
         """
         Retrieve relevant documents based on a given question.
+        If ids_list is provided, the search is filtered by the given IDs.
 
         Args:
             question (str): The question for which relevant documents are retrieved.
+            ids_list (Optional[List[str]]): A list of document IDs to filter the search results.
 
         Returns:
             List[dict]: A list of relevant documents.
         """
-        docs = self.relevant_documents_pipeline.invoke(question)
+        if ids_list:
+            search_kwargs = {k: v for k, v in self.retriever.search_kwargs.items()}
+            if self.vectorDB_class == 'Qdrant':
+                filter = self._get_qdrant_ids_filter(ids_list)
+            else:
+                raise ValueError(f'Celex filter not supported for {self.vectorDB_class}')
 
+            search_kwargs.update({'filter': filter})
+            docs = self.relevant_documents_pipeline.invoke(
+                {'question': question},
+                config={"configurable": {"search_kwargs": search_kwargs}})
+        else:
+            docs = self.relevant_documents_pipeline.invoke({'question': question})
         return docs
 
-    def get_context(self, text:str) -> str:
+    def get_context(self, text: str, ids_list: Optional[List[str]] = None) -> str:
         """
         Retrieve context for a given text.
+        If ids_list is provided, the search is filtered by the given IDs.
 
         Args:
             text (str): The text for which context is retrieved.
+            ids_list (Optional[List[str]]): A list of document IDs to filter the search results.
 
         Returns:
             str: A formatted string containing the relevant documents texts.
         """
 
-        docs = self.get_relevant_docs(text)
+        docs = self.get_relevant_docs(text, ids_list=ids_list)
         return self._format_context_docs(docs)
 
     def _remove_last_messages(self, session_id:str, n:int) -> None:
         """
         Remove last n messages from the chat history of a specific session.
@@ -193,7 +274,6 @@
         for message in message_history:
             chat_history.add_message(message)
 
-
     def _format_history(self, session_id:str) -> str:
         """
         Format chat history for a specific session into a string.
@@ -211,8 +291,7 @@
             formatted_history += f"{message.type}: {message.content}\n\n"
         return formatted_history
 
-
-    def _resize_context(self, context_docs:List[dict]) -> List[dict]:
+    def _resize_context(self, context_docs: List[dict]) -> List[dict]:
         """
         Resize the dimension of the context in terms of number of tokens.
         If the concatenation of document text exceeds max_context_size,
@@ -232,16 +311,24 @@
                 resized_contexts.append(context_docs[i])
                 total_len += l
         return resized_contexts
 
-    def get_answer(self, session_id:str, question:str, context_docs:List[dict], from_tool:bool=False) -> Answer:
+    def get_answer(self,
+                   session_id: str,
+                   question: str,
+                   context_docs: List[dict],
+                   from_tool: bool = False,
+                   ids_list: List[str] = None
+                   ) -> Answer:
         """
         Get an answer to a question of a specific session, considering context documents and history messages.
+        If ids_list is provided, any search for new context documents is filtered by the given IDs.
 
         Args:
             session_id (str): The session ID for which the answer is retrieved.
             question (str): The new user message.
             context_docs (List[dict]): A list of documents used as context to answer the user message.
             from_tool (bool, optional): Whether the question originates from a tool. Defaults to False.
+            ids_list (Optional[List[str]]): A list of document IDs to filter the search results for new context documents.
 
         Returns:
             Answer: An object containing the answer along with a new list of context documents
@@ -264,63 +351,18 @@
             self.get_chat_history(session_id=session_id).add_message(AIMessage(result.content))
             return Answer(answer=result.content, status=-1)
         text = eval(result.additional_kwargs['tool_calls'][0]['function']['arguments'])['text']
-        new_docs = self.get_relevant_docs(text)
+        new_docs = self.get_relevant_docs(text, ids_list=ids_list)
         self._remove_last_messages(session_id=session_id, n=2)
 
         result = self.get_answer(
             session_id=session_id,
             question=question,
             context_docs=new_docs,
-            from_tool=True
+            from_tool=True,
+            ids_list=ids_list
         )
         if result.status == 1:
             return Answer(answer=result.answer, new_documents=new_docs)
         else:
-            return Answer(answer=result.answer)
-        return Answer(answer=result.content)
-
-
-class EurLexChatAkn(EurLexChat):
-    def _parse_documents(self, docs: List[Document]) -> List[dict]:
-        """
-        Parse a list of documents into a standardized format.
-
-        Args:
-            docs (List[Document]): A list of documents to parse.
-
-        Returns:
-            List[dict]: A list of dictionaries, each containing parsed information from the input documents.
-        """
-
-        parsed_documents = []
-
-        for doc in docs:
-            parsed_documents.append({
-                'text': doc.page_content,
-                'source': doc.metadata["uri"],
-                '_id': doc.metadata["uri"] + doc.metadata["article_id"]
-            })
-        return parsed_documents
-
-    def get_relevant_docs(self, question: str, eurovoc: str = None) -> List[dict]:
-        """
-        Retrieve relevant documents based on a given question.
-
-        Args:
-            question (str): The question for which relevant documents are retrieved.
-            eurovoc (str): The Eurovoc to be used as filter
-
-        Returns:
-            List[dict]: A list of relevant documents.
-        """
-        if eurovoc:
-            retriever = get_vectorDB_module(
-                self.config['vectorDB'], self.embedder, metadata={'filter': {'eurovoc': ''}}
-            )
-            relevant_documents_pipeline_with_filter = (retriever | self._parse_documents)
-            docs = relevant_documents_pipeline_with_filter.invoke(
-                question
-            )
-        else:
-            docs = self.relevant_documents_pipeline.invoke(question)
-        return docs
+            return Answer(answer=result.answer)
+        return Answer(answer=result.content)
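
A minimal usage sketch (illustrative, not part of the commit) of the new CELEX-restricted retrieval flow, assuming a valid CONFIG and a populated Qdrant collection; the session ID and question strings are made up for the example:

    from EurLexChat import EurLexChat
    from config import CONFIG
    from consts import JUSTICE_CELEXES

    chat = EurLexChat(config=CONFIG)

    # Resolve CELEX numbers to Qdrant point IDs once, up front.
    justice_ids = chat.get_ids_from_celexes(JUSTICE_CELEXES)

    # Any retrieval triggered while answering is restricted to those points.
    answer = chat.get_answer(
        session_id="demo-session",
        question="Which regulation governs jurisdiction in matrimonial matters?",
        context_docs=[],
        ids_list=justice_ids,
    )
    print(answer.answer)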
app.py CHANGED
@@ -3,6 +3,9 @@ from EurLexChat import EurLexChat
 import random
 import string
 from config import CONFIG, UI_USER, UI_PWD
+from consts import JUSTICE_CELEXES, POLLUTION_CELEXES
+from enum import Enum
+import regex as re
 
 def generate_random_string(length):
     # Generate a random string of the specified length
@@ -11,31 +14,59 @@
     random_string = ''.join(random.choice(characters) for _ in range(length))
     return random_string
 
-class Documents():
+class ChatBot():
     def __init__(self) -> None:
         self.documents = []
+        self.chat = EurLexChat(config=CONFIG)
 
+class Versions(Enum):
+    AKN='Akoma Ntoso'
+    JUSTICE='Organisation of the legal system (1226) eurovoc'
+    POLLUTION='Pollution (2524) eurovoc'
+    BASIC='All eurovoc'
 
-chat = EurLexChat(config=CONFIG)
-docs = Documents()
 
+bot = ChatBot()
+
+justice_ids = bot.chat.get_ids_from_celexes(JUSTICE_CELEXES)
+pollution_ids = bot.chat.get_ids_from_celexes(POLLUTION_CELEXES)
+
+
+def reinit(version):
+    bot.documents = []
+    if version == Versions.AKN.value:
+        CONFIG['vectorDB']['kwargs']['collection_name'] += "-akn"
+    else:
+        CONFIG['vectorDB']['kwargs']['collection_name'] = re.sub(r'-akn$', '', CONFIG['vectorDB']['kwargs']['collection_name'])
+    bot.chat = EurLexChat(config=CONFIG)
+    return clean_page()
 
 def remove_doc(btn):
-    docs.documents.pop(btn)
-    new_accordions, new_texts = set_new_docs_ui(docs.documents)
+    bot.documents.pop(btn)
+    new_accordions, new_texts = set_new_docs_ui(bot.documents)
     return [*new_accordions, *new_texts]
 
 
-def get_answer(message, history, session_id):
+def get_answer(message, history, session_id, celex_type):
     s = session_id
+    if celex_type == Versions.JUSTICE.value:
+        ids_list = justice_ids
+    elif celex_type == Versions.POLLUTION.value:
+        ids_list = pollution_ids
+    elif celex_type == Versions.BASIC.value or celex_type == Versions.AKN.value:
+        ids_list = None
+    else:
+        raise ValueError(f'Wrong celex_type: {celex_type}')
+
     if len(history) == 0:
-        docs.documents = chat.get_relevant_docs(question=message)
+        bot.documents = []
+        #docs.documents = chat.get_relevant_docs(question=message, ids_list=ids_list)
         s = generate_random_string(7)
-    result = chat.get_answer(s, message, docs.documents)
+    result = bot.chat.get_answer(s, message, bot.documents, ids_list=ids_list)
     history.append((message, result.answer))
     if result.new_documents:
-        docs.documents = result.new_documents
-        accordions, list_texts = set_new_docs_ui(docs.documents)
+        bot.documents = result.new_documents
+        accordions, list_texts = set_new_docs_ui(bot.documents)
     return ['', history, gr.Column(scale=1, visible=True), *accordions, *list_texts, s]
 
 
@@ -44,7 +75,7 @@ def set_new_docs_ui(documents):
     new_texts = []
     for i in range(len(accordions)):
         if i < len(documents):
-            new_accordions.append(gr.update(accordions[i].elem_id, label=f"{documents[i]['text'][:45]}...", visible=True, open=False))
+            new_accordions.append(gr.update(accordions[i].elem_id, label=f"{documents[i]['celex']}: {documents[i]['text'][:40]}...", visible=True, open=False))
             new_texts.append(gr.update(list_texts[i].elem_id, value=f"{documents[i]['text']}...", visible=True))
         else:
             new_accordions.append(gr.update(accordions[i].elem_id, label="", visible=False))
@@ -53,15 +84,20 @@
 
 
 def clean_page():
-    docs.documents = []
-    accordions, list_texts = set_new_docs_ui(docs.documents)
-    return ["", [], None, *accordions, *list_texts]
+    bot.documents = []
+    accordions, list_texts = set_new_docs_ui(bot.documents)
+    return ["", [], None, *accordions, *list_texts, gr.Column(visible=False)]
 
 list_texts = []
 accordions = []
 states = []
 delete_buttons = []
 
+if CONFIG['vectorDB'].get('rerank'):
+    n_context_docs = CONFIG['vectorDB']['rerank']['kwargs']['top_n']
+else:
+    n_context_docs = CONFIG['vectorDB']['retriever_args']['search_kwargs']['k']
+
 block = gr.Blocks()
 with block:
 
@@ -71,15 +107,16 @@ with block:
     state = gr.State(value=None)
     with gr.Row():
         with gr.Column(scale=3):
+            drop_down = gr.Dropdown(label='Choose a version', choices=[attribute.value for attribute in Versions], value=Versions.BASIC)
            chatbot = gr.Chatbot()
            with gr.Row():
-                message = gr.Textbox(scale=10)
-                submit = gr.Button("Send", scale=1)
-                clear = gr.Button("Clear", scale=1)
+                message = gr.Textbox(scale=10, label='', placeholder='Write a message...', container=False)
+                submit = gr.Button("Send message", scale=1)
+                clear = gr.Button("Reset chat", scale=1)
 
        with gr.Column(scale=1, visible=False) as col:
            gr.Markdown("""<h3><center>Context documents</center></h3>""")
-            for i in range(CONFIG['vectorDB']['retriever_args']['search_kwargs']['k']):
+            for i in range(n_context_docs):
                with gr.Accordion(label="", elem_id=f'accordion_{i}', open=False) as acc:
                    list_texts.append(gr.Textbox("", elem_id=f'text_{i}', show_label=False, lines=10))
                    btn = gr.Button(f"Remove document")
@@ -101,9 +138,10 @@
    Contact us: <a href="mailto:chat-eur-lex@igsg.cnr.it">chat-eur-lex@igsg.cnr.it</a>.</p>
    </div>""")
 
-    clear.click(clean_page, outputs=[message, chatbot, state, *accordions, *list_texts])
-    message.submit(get_answer, inputs=[message, chatbot, state], outputs=[message, chatbot, col, *accordions, *list_texts, state])
-    submit.click(get_answer, inputs=[message, chatbot, state], outputs=[message, chatbot, col, *accordions, *list_texts, state])
+    drop_down.change(reinit, inputs=[drop_down], outputs=[message, chatbot, state, *accordions, *list_texts, col])
+    clear.click(clean_page, outputs=[message, chatbot, state, *accordions, *list_texts, col])
+    message.submit(get_answer, inputs=[message, chatbot, state, drop_down], outputs=[message, chatbot, col, *accordions, *list_texts, state])
+    submit.click(get_answer, inputs=[message, chatbot, state, drop_down], outputs=[message, chatbot, col, *accordions, *list_texts, state])
    for i, b in enumerate(delete_buttons):
        b.click(remove_doc, inputs=states[i], outputs=[*accordions, *list_texts])
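
A small sketch of what reinit does to the collection name when the version dropdown changes (assuming the initial value chat-eur-lex from config.yaml):

    import regex as re

    name = "chat-eur-lex"
    name += "-akn"                      # user selects the Akoma Ntoso version
    assert name == "chat-eur-lex-akn"
    name = re.sub(r'-akn$', '', name)   # user selects any non-AKN version
    assert name == "chat-eur-lex"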
chat_utils.py CHANGED
@@ -1,6 +1,9 @@
 from dataclasses import dataclass
 from typing import Optional, List
 from langchain.pydantic_v1 import BaseModel, Field
+from langchain_core.runnables import ConfigurableField
+from langchain_core.runnables.base import RunnableLambda
+from operator import itemgetter
 
 SYSTEM_PROMPT = (
     "You are an assistant specialized in the legal and compliance field who must answer and converse with the user using the context provided. " +
@@ -59,12 +62,11 @@ def get_init_modules(config):
     mod_chat = __import__("langchain_community.chat_message_histories",
                           fromlist=[config["chatDB"]["class"]])
     chatDB_class = getattr(mod_chat, config["chatDB"]["class"])
-    retriever = get_vectorDB_module(config['vectorDB'], embedder)
+    retriever, retriever_chain = get_vectorDB_module(config['vectorDB'], embedder)
 
-    return embedder, llm, chatDB_class, retriever
+    return embedder, llm, chatDB_class, retriever, retriever_chain
 
-
-def get_vectorDB_module(db_config, embedder, metadata=None):
+def get_vectorDB_module(db_config, embedder):
     mod_chat = __import__("langchain_community.vectorstores",
                           fromlist=[db_config["class"]])
     vectorDB_class = getattr(mod_chat, db_config["class"])
@@ -85,13 +87,10 @@ def get_vectorDB_module(db_config, embedder, metadata=None):
 
         client = QdrantClient(**client_kwargs)
 
-        if metadata is None:
-            metadata = {}
         retriever = vectorDB_class(
             client, embeddings=embedder, **db_kwargs).as_retriever(
                 search_type=db_config["retriever_args"]["search_type"],
-                search_kwargs={**db_config["retriever_args"]["search_kwargs"], **metadata},
-                filter=metadata
+                search_kwargs={**db_config["retriever_args"]["search_kwargs"]}
         )
 
     else:
@@ -100,4 +99,29 @@
             search_kwargs=db_config["retriever_args"]["search_kwargs"]
         )
 
-    return retriever
+    retriever = retriever.configurable_fields(
+        search_kwargs=ConfigurableField(
+            id="search_kwargs",
+            name="Search Kwargs",
+            description="The search kwargs to use. Includes dynamic category adjustment.",
+        )
+    )
+
+    chain = (RunnableLambda(lambda x: x['question']) | retriever)
+
+    if db_config.get("rerank"):
+        if db_config["rerank"]["class"] == "CohereRerank":
+            module_compressors = __import__("langchain.retrievers.document_compressors",
+                                            fromlist=[db_config["rerank"]["class"]])
+            rerank_class = getattr(module_compressors, db_config["rerank"]["class"])
+            rerank = rerank_class(**db_config["rerank"]["kwargs"])
+
+            chain = ({
+                "docs": chain,
+                "query": itemgetter("question"),
+            } | RunnableLambda(lambda x: rerank.compress_documents(x['docs'], x['query'])))
+        else:
+            raise NotImplementedError(db_config["rerank"]["class"])
+    return retriever, chain
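
A minimal sketch of the per-invocation override that the ConfigurableField above enables, mirroring how EurLexChat.get_relevant_docs uses it; retriever and retriever_chain come from get_vectorDB_module, and ids_filter stands in for a Qdrant filter built elsewhere:

    search_kwargs = dict(retriever.search_kwargs)   # start from the YAML defaults
    search_kwargs["filter"] = ids_filter            # hypothetical rest.Filter instance
    docs = retriever_chain.invoke(
        {"question": "air quality limit values"},
        config={"configurable": {"search_kwargs": search_kwargs}},
    )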
config.py CHANGED
@@ -24,12 +24,22 @@ CONFIG["llm"]["kwargs"]["openai_organization"] = OPENAI_ORG_KEY
 CONFIG["vectorDB"]["kwargs"]["url"] = QDRANT_URL
 CONFIG["vectorDB"]["kwargs"]["api_key"] = QDRANT_KEY
 
+
 # if the history should be stored on AWS DynamoDB
 # otherwise it will be stored on local FS to the output_path defined in the config.yaml file
 if CONFIG['chatDB']['class'] == 'DynamoDBChatMessageHistory':
-    CHATDB_TABLE_NAME = os.getenv("CHATDB_TABLE_NAME", CONFIG["chatDB"]["kwargs"].get("table_name", "ChatEurlexHistory"))
-    AWS_ACCESS_KEY_ID = os.getenv("AWS_ACCESS_KEY_ID", CONFIG["chatDB"]["kwargs"].get("aws_access_key_id", ""))
-    AWS_SECRET_ACCESS_KEY = os.getenv("AWS_SECRET_ACCESS_KEY", CONFIG["chatDB"]["kwargs"].get("aws_secret_access_key", ""))
+    CHATDB_TABLE_NAME = os.getenv("CHATDB_TABLE_NAME",
+                                  CONFIG["chatDB"]["kwargs"].get("table_name", "ChatEurlexHistory"))
+    AWS_ACCESS_KEY_ID = os.getenv("AWS_ACCESS_KEY_ID",
+                                  CONFIG["chatDB"]["kwargs"].get("aws_access_key_id", ""))
+    AWS_SECRET_ACCESS_KEY = os.getenv("AWS_SECRET_ACCESS_KEY",
+                                      CONFIG["chatDB"]["kwargs"].get("aws_secret_access_key", ""))
     CONFIG["chatDB"]["kwargs"]["table_name"] = CHATDB_TABLE_NAME
     CONFIG["chatDB"]["kwargs"]["aws_access_key_id"] = AWS_ACCESS_KEY_ID
     CONFIG["chatDB"]["kwargs"]["aws_secret_access_key"] = AWS_SECRET_ACCESS_KEY
+
+# if Cohere reranking is enabled, look for the API key and assign it to the CONFIG
+if CONFIG['vectorDB'].get('rerank'):
+    COHERE_KEY = os.getenv("COHERE_API_KEY",
+                           CONFIG["vectorDB"]["rerank"]["kwargs"].get("cohere_api_key", ""))
+    CONFIG["vectorDB"]["rerank"]["kwargs"]["cohere_api_key"] = COHERE_KEY
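
The precedence used throughout config.py: the environment variable wins, and the YAML value (often an empty placeholder) is only a fallback. A standalone illustration of the same pattern:

    import os

    yaml_value = CONFIG["vectorDB"]["rerank"]["kwargs"].get("cohere_api_key", "")
    cohere_key = os.getenv("COHERE_API_KEY", yaml_value)  # env var takes precedence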
config.yaml CHANGED
@@ -4,15 +4,22 @@ vectorDB:
   url: ""
   api_key: ""
   collection_name: chat-eur-lex
+  timeout: 60
 
   retriever_args:
     search_type: mmr
     search_kwargs:
-      k: 15
+      k: 100
       fetch_k: 300
-      score_threshold: 0.0
       lambda_mult: 0.8
 
+  rerank:
+    class: CohereRerank
+    kwargs:
+      cohere_api_key: ""
+      model: rerank-multilingual-v3.0
+      top_n: 15
+
 embeddings:
   class: OpenAIEmbeddings
   kwargs:
@@ -22,9 +29,9 @@ embeddings:
 llm:
   class: ChatOpenAI
   use_context_function: True
-  max_context_size: 6000
+  max_context_size: 12000
   kwargs:
-    model_name: gpt-4
+    model_name: gpt-4o
     temperature: 0.8
 
 
@@ -35,4 +42,4 @@ chatDB:
     aws_access_key_id: ''
     aws_secret_access_key: ''
 
-max_history_messages: 5
+max_history_messages: 10
consts.py ADDED
@@ -0,0 +1,73 @@
+JUSTICE_CELEXES = [
+    "32024D0414",
+    "32023D2098",
+    "32023D0133",
+    "32022D0998",
+    "32022D0494",
+    "32021D1711",
+    "32021D1943",
+    "32021R0693",
+    "32020D1117",
+    "32019D1798",
+    "32019D1564",
+    "32019R1111",
+    "32019D0844",
+    "32019R0629",
+    "32019D0598",
+    "32018R1990",
+    "32018R1935",
+    "32018D1275",
+    "32018D1103",
+    "32018D1094",
+    "02018D1696-20200711",
+    "32018D0856",
+    "02017R1939-20210110",
+    "32017D0973",
+    "32016D1990",
+    "32016R1192",
+    "32016R1104",
+    "32016R1103",
+    "32016D0947",
+    "32016D0954",
+    "32016D0454",
+    "32015R2422",
+    "32015D1380",
+    "32014R1329",
+    "32014D0887",
+    "32014D0444",
+    "32013L0048",
+    "02012R1215-20150110",
+    "32012R0650",
+    "32011R0969",
+    "32009D0026",
+    "02009R0004-20150312",
+    "32008R0593",
+    "32007D0712",
+    "32005F0667",
+    "32005D0150",
+    "32004D0407",
+    "32002D0971"
+]
+
+POLLUTION_CELEXES = [
+    "32022D0591",
+    "02018R0842-20230516",
+    "32006D0871",
+    "22006A1208(04)",
+    "32021R1119",
+    "32021R0783",
+    "32020R0852",
+    "02019R0856-20210811",
+    "02017R1369-20210501",
+    "32016D1841",
+    "22016A1019(01)",
+    "32015L2193",
+    "02015R0757-20161216",
+    "32023R1115",
+    "32023R0955",
+    "32022D0591",
+    "02018R2067-20210101",
+    "02018R2067-20210101",
+    "32021R1119",
+    "32020R1294"
+]
requirements.txt CHANGED
@@ -1,8 +1,9 @@
-langchain==0.1.6
+langchain==0.1.14
 lxml==4.9.2
-tiktoken==0.6.0
+tiktoken==0.7.0
 qdrant-client==1.7.3
 transformers==4.37.2
 openai==1.12.0
 gradio==4.18.0
-boto3==1.34
+boto3==1.34
+cohere==5.5.8