Spaces:
Sleeping
Sleeping
Merge branch 'main' into pdf-render
Browse files- README.md +7 -4
- document_qa/document_qa_engine.py +23 -9
- pyproject.toml +1 -1
- streamlit_app.py +25 -2
README.md
CHANGED
@@ -16,11 +16,14 @@ license: apache-2.0
|
|
16 |
|
17 |
## Introduction
|
18 |
|
19 |
-
Question/Answering on scientific documents using LLMs
|
20 |
-
|
21 |
-
Differently to most of the
|
|
|
22 |
|
23 |
-
|
|
|
|
|
24 |
|
25 |
**Demos**:
|
26 |
- (on HuggingFace spaces): https://lfoppiano-document-qa.hf.space/
|
|
|
16 |
|
17 |
## Introduction
|
18 |
|
19 |
+
Question/Answering on scientific documents using LLMs: ChatGPT-3.5-turbo, Mistral-7b-instruct and Zephyr-7b-beta.
|
20 |
+
The streamlit application demonstrate the implementaiton of a RAG (Retrieval Augmented Generation) on scientific documents, that we are developing at NIMS (National Institute for Materials Science), in Tsukuba, Japan.
|
21 |
+
Differently to most of the projects, we focus on scientific articles.
|
22 |
+
We target only the full-text using [Grobid](https://github.com/kermitt2/grobid) that provide and cleaner results than the raw PDF2Text converter (which is comparable with most of other solutions).
|
23 |
|
24 |
+
Additionally, this frontend provides the visualisation of named entities on LLM responses to extract <span stype="color:yellow">physical quantities, measurements</span> (with [grobid-quantities](https://github.com/kermitt2/grobid-quantities)) and <span stype="color:blue">materials</span> mentions (with [grobid-superconductors](https://github.com/lfoppiano/grobid-superconductors)).
|
25 |
+
|
26 |
+
The conversation is backed up by a sliding window memory (top 4 more recent messages) that help refers to information previously discussed in the chat.
|
27 |
|
28 |
**Demos**:
|
29 |
- (on HuggingFace spaces): https://lfoppiano-document-qa.hf.space/
|
document_qa/document_qa_engine.py
CHANGED
@@ -23,7 +23,13 @@ class DocumentQAEngine:
|
|
23 |
embeddings_map_from_md5 = {}
|
24 |
embeddings_map_to_md5 = {}
|
25 |
|
26 |
-
def __init__(self,
|
|
|
|
|
|
|
|
|
|
|
|
|
27 |
self.embedding_function = embedding_function
|
28 |
self.llm = llm
|
29 |
self.chain = load_qa_chain(llm, chain_type=qa_chain_type)
|
@@ -81,14 +87,14 @@ class DocumentQAEngine:
|
|
81 |
return self.embeddings_map_from_md5[md5]
|
82 |
|
83 |
def query_document(self, query: str, doc_id, output_parser=None, context_size=4, extraction_schema=None,
|
84 |
-
verbose=False) -> (
|
85 |
Any, str):
|
86 |
# self.load_embeddings(self.embeddings_root_path)
|
87 |
|
88 |
if verbose:
|
89 |
print(query)
|
90 |
|
91 |
-
response = self._run_query(doc_id, query, context_size=context_size)
|
92 |
response = response['output_text'] if 'output_text' in response else response
|
93 |
|
94 |
if verbose:
|
@@ -138,9 +144,15 @@ class DocumentQAEngine:
|
|
138 |
|
139 |
return parsed_output
|
140 |
|
141 |
-
def _run_query(self, doc_id, query, context_size=4):
|
142 |
relevant_documents = self._get_context(doc_id, query, context_size)
|
143 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
144 |
# return self.chain({"input_documents": relevant_documents, "question": prompt_chat_template}, return_only_outputs=True)
|
145 |
|
146 |
def _get_context(self, doc_id, query, context_size=4):
|
@@ -150,6 +162,7 @@ class DocumentQAEngine:
|
|
150 |
return relevant_documents
|
151 |
|
152 |
def get_all_context_by_document(self, doc_id):
|
|
|
153 |
db = self.embeddings_dict[doc_id]
|
154 |
docs = db.get()
|
155 |
return docs['documents']
|
@@ -161,6 +174,7 @@ class DocumentQAEngine:
|
|
161 |
return relevant_documents
|
162 |
|
163 |
def get_text_from_document(self, pdf_file_path, chunk_size=-1, perc_overlap=0.1, verbose=False):
|
|
|
164 |
if verbose:
|
165 |
print("File", pdf_file_path)
|
166 |
filename = Path(pdf_file_path).stem
|
@@ -209,18 +223,17 @@ class DocumentQAEngine:
|
|
209 |
|
210 |
if hash not in self.embeddings_dict.keys():
|
211 |
self.embeddings_dict[hash] = Chroma.from_texts(texts, embedding=self.embedding_function, metadatas=metadata,
|
212 |
-
|
213 |
else:
|
214 |
self.embeddings_dict[hash].delete(ids=self.embeddings_dict[hash].get()['ids'])
|
215 |
self.embeddings_dict[hash] = Chroma.from_texts(texts, embedding=self.embedding_function, metadatas=metadata,
|
216 |
collection_name=hash)
|
217 |
|
218 |
-
|
219 |
self.embeddings_root_path = None
|
220 |
|
221 |
return hash
|
222 |
|
223 |
-
def create_embeddings(self, pdfs_dir_path: Path):
|
224 |
input_files = []
|
225 |
for root, dirs, files in os.walk(pdfs_dir_path, followlinks=False):
|
226 |
for file_ in files:
|
@@ -238,7 +251,8 @@ class DocumentQAEngine:
|
|
238 |
print(data_path, "exists. Skipping it ")
|
239 |
continue
|
240 |
|
241 |
-
texts, metadata, ids = self.get_text_from_document(input_file, chunk_size=
|
|
|
242 |
filename = metadata[0]['filename']
|
243 |
|
244 |
vector_db_document = Chroma.from_texts(texts,
|
|
|
23 |
embeddings_map_from_md5 = {}
|
24 |
embeddings_map_to_md5 = {}
|
25 |
|
26 |
+
def __init__(self,
|
27 |
+
llm,
|
28 |
+
embedding_function,
|
29 |
+
qa_chain_type="stuff",
|
30 |
+
embeddings_root_path=None,
|
31 |
+
grobid_url=None,
|
32 |
+
):
|
33 |
self.embedding_function = embedding_function
|
34 |
self.llm = llm
|
35 |
self.chain = load_qa_chain(llm, chain_type=qa_chain_type)
|
|
|
87 |
return self.embeddings_map_from_md5[md5]
|
88 |
|
89 |
def query_document(self, query: str, doc_id, output_parser=None, context_size=4, extraction_schema=None,
|
90 |
+
verbose=False, memory=None) -> (
|
91 |
Any, str):
|
92 |
# self.load_embeddings(self.embeddings_root_path)
|
93 |
|
94 |
if verbose:
|
95 |
print(query)
|
96 |
|
97 |
+
response = self._run_query(doc_id, query, context_size=context_size, memory=memory)
|
98 |
response = response['output_text'] if 'output_text' in response else response
|
99 |
|
100 |
if verbose:
|
|
|
144 |
|
145 |
return parsed_output
|
146 |
|
147 |
+
def _run_query(self, doc_id, query, memory=None, context_size=4):
|
148 |
relevant_documents = self._get_context(doc_id, query, context_size)
|
149 |
+
if memory:
|
150 |
+
return self.chain.run(input_documents=relevant_documents,
|
151 |
+
question=query)
|
152 |
+
else:
|
153 |
+
return self.chain.run(input_documents=relevant_documents,
|
154 |
+
question=query,
|
155 |
+
memory=memory)
|
156 |
# return self.chain({"input_documents": relevant_documents, "question": prompt_chat_template}, return_only_outputs=True)
|
157 |
|
158 |
def _get_context(self, doc_id, query, context_size=4):
|
|
|
162 |
return relevant_documents
|
163 |
|
164 |
def get_all_context_by_document(self, doc_id):
|
165 |
+
"""Return the full context from the document"""
|
166 |
db = self.embeddings_dict[doc_id]
|
167 |
docs = db.get()
|
168 |
return docs['documents']
|
|
|
174 |
return relevant_documents
|
175 |
|
176 |
def get_text_from_document(self, pdf_file_path, chunk_size=-1, perc_overlap=0.1, verbose=False):
|
177 |
+
"""Extract text from documents using Grobid, if chunk_size is < 0 it keep each paragraph separately"""
|
178 |
if verbose:
|
179 |
print("File", pdf_file_path)
|
180 |
filename = Path(pdf_file_path).stem
|
|
|
223 |
|
224 |
if hash not in self.embeddings_dict.keys():
|
225 |
self.embeddings_dict[hash] = Chroma.from_texts(texts, embedding=self.embedding_function, metadatas=metadata,
|
226 |
+
collection_name=hash)
|
227 |
else:
|
228 |
self.embeddings_dict[hash].delete(ids=self.embeddings_dict[hash].get()['ids'])
|
229 |
self.embeddings_dict[hash] = Chroma.from_texts(texts, embedding=self.embedding_function, metadatas=metadata,
|
230 |
collection_name=hash)
|
231 |
|
|
|
232 |
self.embeddings_root_path = None
|
233 |
|
234 |
return hash
|
235 |
|
236 |
+
def create_embeddings(self, pdfs_dir_path: Path, chunk_size=500, perc_overlap=0.1):
|
237 |
input_files = []
|
238 |
for root, dirs, files in os.walk(pdfs_dir_path, followlinks=False):
|
239 |
for file_ in files:
|
|
|
251 |
print(data_path, "exists. Skipping it ")
|
252 |
continue
|
253 |
|
254 |
+
texts, metadata, ids = self.get_text_from_document(input_file, chunk_size=chunk_size,
|
255 |
+
perc_overlap=perc_overlap)
|
256 |
filename = metadata[0]['filename']
|
257 |
|
258 |
vector_db_document = Chroma.from_texts(texts,
|
pyproject.toml
CHANGED
@@ -3,7 +3,7 @@ requires = ["setuptools", "setuptools-scm"]
|
|
3 |
build-backend = "setuptools.build_meta"
|
4 |
|
5 |
[tool.bumpversion]
|
6 |
-
current_version = "0.
|
7 |
commit = "true"
|
8 |
tag = "true"
|
9 |
tag_name = "v{new_version}"
|
|
|
3 |
build-backend = "setuptools.build_meta"
|
4 |
|
5 |
[tool.bumpversion]
|
6 |
+
current_version = "0.3.0"
|
7 |
commit = "true"
|
8 |
tag = "true"
|
9 |
tag_name = "v{new_version}"
|
streamlit_app.py
CHANGED
@@ -7,6 +7,7 @@ from tempfile import NamedTemporaryFile
|
|
7 |
import dotenv
|
8 |
from grobid_quantities.quantities import QuantitiesAPI
|
9 |
from langchain.llms.huggingface_hub import HuggingFaceHub
|
|
|
10 |
|
11 |
dotenv.load_dotenv(override=True)
|
12 |
|
@@ -52,6 +53,9 @@ if 'ner_processing' not in st.session_state:
|
|
52 |
if 'uploaded' not in st.session_state:
|
53 |
st.session_state['uploaded'] = False
|
54 |
|
|
|
|
|
|
|
55 |
if 'binary' not in st.session_state:
|
56 |
st.session_state['binary'] = None
|
57 |
|
@@ -82,6 +86,11 @@ def new_file():
|
|
82 |
st.session_state['loaded_embeddings'] = None
|
83 |
st.session_state['doc_id'] = None
|
84 |
st.session_state['uploaded'] = True
|
|
|
|
|
|
|
|
|
|
|
85 |
|
86 |
|
87 |
# @st.cache_resource
|
@@ -112,6 +121,7 @@ def init_qa(model, api_key=None):
|
|
112 |
else:
|
113 |
st.error("The model was not loaded properly. Try reloading. ")
|
114 |
st.stop()
|
|
|
115 |
|
116 |
return DocumentQAEngine(chat, embeddings, grobid_url=os.environ['GROBID_URL'])
|
117 |
|
@@ -183,7 +193,7 @@ with st.sidebar:
|
|
183 |
disabled=st.session_state['doc_id'] is not None or st.session_state['uploaded'])
|
184 |
|
185 |
st.markdown(
|
186 |
-
":warning: Mistral and Zephyr are
|
187 |
|
188 |
if (model == 'mistral-7b-instruct-v0.1' or model == 'zephyr-7b-beta') and model not in st.session_state['api_keys']:
|
189 |
if 'HUGGINGFACEHUB_API_TOKEN' not in os.environ:
|
@@ -219,6 +229,12 @@ with st.sidebar:
|
|
219 |
st.session_state['rqa'][model] = init_qa(model)
|
220 |
# else:
|
221 |
# is_api_key_provided = st.session_state['api_key']
|
|
|
|
|
|
|
|
|
|
|
|
|
222 |
left_column, right_column = st.columns([1, 1])
|
223 |
|
224 |
with right_column:
|
@@ -349,7 +365,8 @@ with right_column:
|
|
349 |
elif mode == "LLM":
|
350 |
with st.spinner("Generating response..."):
|
351 |
_, text_response = st.session_state['rqa'][model].query_document(question, st.session_state.doc_id,
|
352 |
-
|
|
|
353 |
|
354 |
if not text_response:
|
355 |
st.error("Something went wrong. Contact Luca Foppiano (Foppiano.Luca@nims.co.jp) to report the issue.")
|
@@ -368,5 +385,11 @@ with right_column:
|
|
368 |
st.write(text_response)
|
369 |
st.session_state.messages.append({"role": "assistant", "mode": mode, "content": text_response})
|
370 |
|
|
|
|
|
|
|
|
|
|
|
|
|
371 |
elif st.session_state.loaded_embeddings and st.session_state.doc_id:
|
372 |
play_old_messages()
|
|
|
7 |
import dotenv
|
8 |
from grobid_quantities.quantities import QuantitiesAPI
|
9 |
from langchain.llms.huggingface_hub import HuggingFaceHub
|
10 |
+
from langchain.memory import ConversationBufferWindowMemory
|
11 |
|
12 |
dotenv.load_dotenv(override=True)
|
13 |
|
|
|
53 |
if 'uploaded' not in st.session_state:
|
54 |
st.session_state['uploaded'] = False
|
55 |
|
56 |
+
if 'memory' not in st.session_state:
|
57 |
+
st.session_state['memory'] = ConversationBufferWindowMemory(k=4)
|
58 |
+
|
59 |
if 'binary' not in st.session_state:
|
60 |
st.session_state['binary'] = None
|
61 |
|
|
|
86 |
st.session_state['loaded_embeddings'] = None
|
87 |
st.session_state['doc_id'] = None
|
88 |
st.session_state['uploaded'] = True
|
89 |
+
st.session_state['memory'].clear()
|
90 |
+
|
91 |
+
|
92 |
+
def clear_memory():
|
93 |
+
st.session_state['memory'].clear()
|
94 |
|
95 |
|
96 |
# @st.cache_resource
|
|
|
121 |
else:
|
122 |
st.error("The model was not loaded properly. Try reloading. ")
|
123 |
st.stop()
|
124 |
+
return
|
125 |
|
126 |
return DocumentQAEngine(chat, embeddings, grobid_url=os.environ['GROBID_URL'])
|
127 |
|
|
|
193 |
disabled=st.session_state['doc_id'] is not None or st.session_state['uploaded'])
|
194 |
|
195 |
st.markdown(
|
196 |
+
":warning: Mistral and Zephyr are **FREE** to use. Requests might fail anytime. Use at your own risk. :warning: ")
|
197 |
|
198 |
if (model == 'mistral-7b-instruct-v0.1' or model == 'zephyr-7b-beta') and model not in st.session_state['api_keys']:
|
199 |
if 'HUGGINGFACEHUB_API_TOKEN' not in os.environ:
|
|
|
229 |
st.session_state['rqa'][model] = init_qa(model)
|
230 |
# else:
|
231 |
# is_api_key_provided = st.session_state['api_key']
|
232 |
+
|
233 |
+
st.button(
|
234 |
+
'Reset chat memory.',
|
235 |
+
on_click=clear_memory(),
|
236 |
+
help="Clear the conversational memory. Currently implemented to retrain the 4 most recent messages.")
|
237 |
+
|
238 |
left_column, right_column = st.columns([1, 1])
|
239 |
|
240 |
with right_column:
|
|
|
365 |
elif mode == "LLM":
|
366 |
with st.spinner("Generating response..."):
|
367 |
_, text_response = st.session_state['rqa'][model].query_document(question, st.session_state.doc_id,
|
368 |
+
context_size=context_size,
|
369 |
+
memory=st.session_state.memory)
|
370 |
|
371 |
if not text_response:
|
372 |
st.error("Something went wrong. Contact Luca Foppiano (Foppiano.Luca@nims.co.jp) to report the issue.")
|
|
|
385 |
st.write(text_response)
|
386 |
st.session_state.messages.append({"role": "assistant", "mode": mode, "content": text_response})
|
387 |
|
388 |
+
for id in range(0, len(st.session_state.messages), 2):
|
389 |
+
question = st.session_state.messages[id]['content']
|
390 |
+
if len(st.session_state.messages) > id + 1:
|
391 |
+
answer = st.session_state.messages[id + 1]['content']
|
392 |
+
st.session_state.memory.save_context({"input": question}, {"output": answer})
|
393 |
+
|
394 |
elif st.session_state.loaded_embeddings and st.session_state.doc_id:
|
395 |
play_old_messages()
|