Luca Foppiano committed on
Commit 16cf398
Parents: ad304e2 55e39a2

Merge branch 'main' into fix-conversational-memory

document_qa/document_qa_engine.py CHANGED
@@ -3,6 +3,7 @@ import os
 from pathlib import Path
 from typing import Union, Any
 
+from document_qa.grobid_processors import GrobidProcessor
 from grobid_client.grobid_client import GrobidClient
 from langchain.chains import create_extraction_chain, ConversationChain, ConversationalRetrievalChain
 from langchain.chains.question_answering import load_qa_chain, stuff_prompt, refine_prompts, map_reduce_prompt, \
@@ -14,8 +15,6 @@ from langchain.text_splitter import RecursiveCharacterTextSplitter
 from langchain.vectorstores import Chroma
 from tqdm import tqdm
 
-from document_qa.grobid_processors import GrobidProcessor
-
 
 class DocumentQAEngine:
     llm = None
@@ -188,8 +187,10 @@ class DocumentQAEngine:
         relevant_documents = multi_query_retriever.get_relevant_documents(query)
         return relevant_documents
 
-    def get_text_from_document(self, pdf_file_path, chunk_size=-1, perc_overlap=0.1, verbose=False):
-        """Extract text from documents using Grobid, if chunk_size is < 0 it keep each paragraph separately"""
+    def get_text_from_document(self, pdf_file_path, chunk_size=-1, perc_overlap=0.1, include=(), verbose=False):
+        """
+        Extract text from documents using Grobid, if chunk_size is < 0 it keeps each paragraph separately
+        """
         if verbose:
             print("File", pdf_file_path)
         filename = Path(pdf_file_path).stem
@@ -204,6 +205,7 @@ class DocumentQAEngine:
         texts = []
         metadatas = []
         ids = []
+
         if chunk_size < 0:
             for passage in structure['passages']:
                 biblio_copy = copy.copy(biblio)
@@ -227,10 +229,25 @@ class DocumentQAEngine:
             metadatas = [biblio for _ in range(len(texts))]
             ids = [id for id, t in enumerate(texts)]
 
+        if "biblio" in include:
+            biblio_metadata = copy.copy(biblio)
+            biblio_metadata['type'] = "biblio"
+            biblio_metadata['section'] = "header"
+            for key in ['title', 'authors', 'publication_year']:
+                if key in biblio_metadata:
+                    texts.append("{}: {}".format(key, biblio_metadata[key]))
+                    metadatas.append(biblio_metadata)
+                    ids.append(key)
+
         return texts, metadatas, ids
 
-    def create_memory_embeddings(self, pdf_path, doc_id=None, chunk_size=500, perc_overlap=0.1):
-        texts, metadata, ids = self.get_text_from_document(pdf_path, chunk_size=chunk_size, perc_overlap=perc_overlap)
+    def create_memory_embeddings(self, pdf_path, doc_id=None, chunk_size=500, perc_overlap=0.1, include_biblio=False):
+        include = ["biblio"] if include_biblio else []
+        texts, metadata, ids = self.get_text_from_document(
+            pdf_path,
+            chunk_size=chunk_size,
+            perc_overlap=perc_overlap,
+            include=include)
         if doc_id:
             hash = doc_id
         else:
@@ -252,7 +269,7 @@ class DocumentQAEngine:
 
         return hash
 
-    def create_embeddings(self, pdfs_dir_path: Path, chunk_size=500, perc_overlap=0.1):
+    def create_embeddings(self, pdfs_dir_path: Path, chunk_size=500, perc_overlap=0.1, include_biblio=False):
         input_files = []
         for root, dirs, files in os.walk(pdfs_dir_path, followlinks=False):
             for file_ in files:
@@ -269,9 +286,12 @@ class DocumentQAEngine:
             if os.path.exists(data_path):
                 print(data_path, "exists. Skipping it ")
                 continue
-
-            texts, metadata, ids = self.get_text_from_document(input_file, chunk_size=chunk_size,
-                                                               perc_overlap=perc_overlap)
+            include = ["biblio"] if include_biblio else []
+            texts, metadata, ids = self.get_text_from_document(
+                input_file,
+                chunk_size=chunk_size,
+                perc_overlap=perc_overlap,
+                include=include)
             filename = metadata[0]['filename']
 
             vector_db_document = Chroma.from_texts(texts,
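The net effect of these hunks is that bibliographic header fields (title, authors, publication year) become standalone passages in the embeddings index. A minimal usage sketch of the new signature, assuming `engine` is a `DocumentQAEngine` already wired up with an LLM and an embeddings function; the variable name and file name are illustrative, not part of the diff:

```python
# A minimal sketch, assuming `engine` is a constructed DocumentQAEngine
# and "paper.pdf" is a hypothetical input file.
doc_id = engine.create_memory_embeddings(
    "paper.pdf",
    chunk_size=500,
    perc_overlap=0.1,
    include_biblio=True,  # indexes "title: ...", "authors: ...", "publication_year: ..." passages
)
```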
document_qa/grobid_processors.py CHANGED
@@ -171,7 +171,7 @@ class GrobidProcessor(BaseProcessor):
             }
         try:
             year = dateparser.parse(doc_biblio.header.date).year
-            biblio["year"] = year
+            biblio["publication_year"] = year
         except:
             pass
 
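This rename keeps `grobid_processors.py` in sync with the `'publication_year'` key that `get_text_from_document` now looks for, so the header year actually reaches the biblio passages. A sketch of the parsing behavior, with a hypothetical date string standing in for `doc_biblio.header.date`:

```python
# Sketch of the renamed field; "2021-03-15" stands in for doc_biblio.header.date.
import dateparser

biblio = {}
try:
    biblio["publication_year"] = dateparser.parse("2021-03-15").year  # -> 2021
except AttributeError:
    # dateparser.parse() returns None when the string cannot be parsed
    pass
```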
streamlit_app.py CHANGED
@@ -288,7 +288,8 @@ if uploaded_file and not st.session_state.loaded_embeddings:
         # hash = get_file_hash(tmp_file.name)[:10]
         st.session_state['doc_id'] = hash = st.session_state['rqa'][model].create_memory_embeddings(tmp_file.name,
                                                                                                     chunk_size=chunk_size,
-                                                                                                    perc_overlap=0.1)
+                                                                                                    perc_overlap=0.1,
+                                                                                                    include_biblio=True)
         st.session_state['loaded_embeddings'] = True
         st.session_state.messages = []
 
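Passing `include_biblio=True` here means questions about the paper itself (its title, authors, or year) can be answered from the vector store. A rough illustration of why standalone header passages help retrieval, using plain Chroma calls with an assumed embeddings backend; the texts and query are made up, not code from this repo:

```python
# Rough illustration: header fields stored as their own passages are
# easy targets for similarity search.
from langchain.embeddings import OpenAIEmbeddings  # assumed embeddings backend
from langchain.vectorstores import Chroma

db = Chroma.from_texts(
    ["title: An Example Paper", "authors: J. Doe", "publication_year: 2021"],
    OpenAIEmbeddings(),
)
hits = db.similarity_search("Who are the authors of this paper?", k=1)
print(hits[0].page_content)  # likely the "authors: ..." passage
```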