Luca Foppiano committed
Commit e7425e5 (2 parents: 55e39a2, 16cf398)

Merge pull request #20 from lfoppiano/fix-conversational-memory

document_qa/document_qa_engine.py CHANGED
@@ -5,10 +5,12 @@ from typing import Union, Any
 
 from document_qa.grobid_processors import GrobidProcessor
 from grobid_client.grobid_client import GrobidClient
-from langchain.chains import create_extraction_chain
-from langchain.chains.question_answering import load_qa_chain
+from langchain.chains import create_extraction_chain, ConversationChain, ConversationalRetrievalChain
+from langchain.chains.question_answering import load_qa_chain, stuff_prompt, refine_prompts, map_reduce_prompt, \
+    map_rerank_prompt
 from langchain.prompts import SystemMessagePromptTemplate, HumanMessagePromptTemplate, ChatPromptTemplate
 from langchain.retrievers import MultiQueryRetriever
+from langchain.schema import Document
 from langchain.text_splitter import RecursiveCharacterTextSplitter
 from langchain.vectorstores import Chroma
 from tqdm import tqdm
@@ -22,15 +24,28 @@ class DocumentQAEngine:
     embeddings_map_from_md5 = {}
     embeddings_map_to_md5 = {}
 
+    default_prompts = {
+        'stuff': stuff_prompt,
+        'refine': refine_prompts,
+        "map_reduce": map_reduce_prompt,
+        "map_rerank": map_rerank_prompt
+    }
+
     def __init__(self,
                  llm,
                  embedding_function,
                  qa_chain_type="stuff",
                  embeddings_root_path=None,
                  grobid_url=None,
+                 memory=None
                  ):
         self.embedding_function = embedding_function
         self.llm = llm
+        # if memory:
+        #     prompt = self.default_prompts[qa_chain_type].PROMPT_SELECTOR.get_prompt(llm)
+        #     self.chain = load_qa_chain(llm, chain_type=qa_chain_type, prompt=prompt, memory=memory)
+        # else:
+        self.memory = memory
         self.chain = load_qa_chain(llm, chain_type=qa_chain_type)
 
         if embeddings_root_path is not None:
@@ -86,14 +101,14 @@ class DocumentQAEngine:
         return self.embeddings_map_from_md5[md5]
 
     def query_document(self, query: str, doc_id, output_parser=None, context_size=4, extraction_schema=None,
-                       verbose=False, memory=None) -> (
+                       verbose=False) -> (
             Any, str):
         # self.load_embeddings(self.embeddings_root_path)
 
         if verbose:
             print(query)
 
-        response = self._run_query(doc_id, query, context_size=context_size, memory=memory)
+        response = self._run_query(doc_id, query, context_size=context_size)
         response = response['output_text'] if 'output_text' in response else response
 
         if verbose:
@@ -143,21 +158,21 @@ class DocumentQAEngine:
 
         return parsed_output
 
-    def _run_query(self, doc_id, query, context_size=4, memory=None):
+    def _run_query(self, doc_id, query, context_size=4):
         relevant_documents = self._get_context(doc_id, query, context_size)
-        if memory:
-            return self.chain.run(input_documents=relevant_documents,
-                                  question=query)
-        else:
-            return self.chain.run(input_documents=relevant_documents,
-                                  question=query,
-                                  memory=memory)
-        # return self.chain({"input_documents": relevant_documents, "question": prompt_chat_template}, return_only_outputs=True)
+        response = self.chain.run(input_documents=relevant_documents,
+                                  question=query)
+
+        if self.memory:
+            self.memory.save_context({"input": query}, {"output": response})
+        return response
 
     def _get_context(self, doc_id, query, context_size=4):
         db = self.embeddings_dict[doc_id]
         retriever = db.as_retriever(search_kwargs={"k": context_size})
         relevant_documents = retriever.get_relevant_documents(query)
+        if self.memory and len(self.memory.buffer_as_messages) > 0:
+            relevant_documents.append(Document(page_content="Previous conversation:\n{}\n\n".format(self.memory.buffer_as_str)))
         return relevant_documents
 
     def get_all_context_by_document(self, doc_id):
@@ -239,11 +254,15 @@ class DocumentQAEngine:
         hash = metadata[0]['hash']
 
         if hash not in self.embeddings_dict.keys():
-            self.embeddings_dict[hash] = Chroma.from_texts(texts, embedding=self.embedding_function, metadatas=metadata,
+            self.embeddings_dict[hash] = Chroma.from_texts(texts,
+                                                           embedding=self.embedding_function,
+                                                           metadatas=metadata,
                                                            collection_name=hash)
         else:
             self.embeddings_dict[hash].delete(ids=self.embeddings_dict[hash].get()['ids'])
-            self.embeddings_dict[hash] = Chroma.from_texts(texts, embedding=self.embedding_function, metadatas=metadata,
+            self.embeddings_dict[hash] = Chroma.from_texts(texts,
+                                                           embedding=self.embedding_function,
+                                                           metadatas=metadata,
                                                            collection_name=hash)
 
         self.embeddings_root_path = None
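
The engine now owns the conversational memory: after each answered query, _run_query writes the exchange back with save_context(), and on the next turn _get_context injects the buffered history into the retrieved context as one extra pseudo-document instead of handing a memory object to chain.run(). A minimal sketch of that flow, assuming langchain's ConversationBufferWindowMemory; the window size and example query/answer strings are illustrative only:

    from langchain.memory import ConversationBufferWindowMemory
    from langchain.schema import Document

    memory = ConversationBufferWindowMemory(k=4)  # keep only the last 4 exchanges

    # After each answered query, _run_query records the exchange:
    memory.save_context({"input": "What material is studied?"},
                        {"output": "MgB2 thin films."})

    # On the next query, _get_context appends the buffered history to the
    # retrieved chunks as a single pseudo-document:
    relevant_documents = []  # normally retriever.get_relevant_documents(query)
    if len(memory.buffer_as_messages) > 0:
        relevant_documents.append(Document(
            page_content="Previous conversation:\n{}\n\n".format(memory.buffer_as_str)))

Riding the history along with the retrieved context keeps the change compatible with any load_qa_chain chain type, since nothing beyond input_documents and question is passed to the chain.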
streamlit_app.py CHANGED
@@ -5,6 +5,7 @@ from tempfile import NamedTemporaryFile
 
 import dotenv
 from grobid_quantities.quantities import QuantitiesAPI
+from langchain.callbacks import PromptLayerCallbackHandler
 from langchain.llms.huggingface_hub import HuggingFaceHub
 from langchain.memory import ConversationBufferWindowMemory
 
@@ -80,6 +81,7 @@ def clear_memory():
 
 # @st.cache_resource
 def init_qa(model, api_key=None):
+    ## For debug add: callbacks=[PromptLayerCallbackHandler(pl_tags=["langchain", "chatgpt", "document-qa"])])
     if model == 'chatgpt-3.5-turbo':
         if api_key:
             chat = ChatOpenAI(model_name="gpt-3.5-turbo",
@@ -108,7 +110,7 @@ def init_qa(model, api_key=None):
         st.stop()
         return
 
-    return DocumentQAEngine(chat, embeddings, grobid_url=os.environ['GROBID_URL'])
+    return DocumentQAEngine(chat, embeddings, grobid_url=os.environ['GROBID_URL'], memory=st.session_state['memory'])
 
 
 @st.cache_resource
@@ -316,8 +318,7 @@ if st.session_state.loaded_embeddings and question and len(question) > 0 and st.
         elif mode == "LLM":
             with st.spinner("Generating response..."):
                 _, text_response = st.session_state['rqa'][model].query_document(question, st.session_state.doc_id,
-                                                                                 context_size=context_size,
-                                                                                 memory=st.session_state.memory)
+                                                                                 context_size=context_size)
 
         if not text_response:
             st.error("Something went wrong. Contact Luca Foppiano (Foppiano.Luca@nims.co.jp) to report the issue.")
@@ -336,11 +337,11 @@ if st.session_state.loaded_embeddings and question and len(question) > 0 and st.
         st.write(text_response)
         st.session_state.messages.append({"role": "assistant", "mode": mode, "content": text_response})
 
-    for id in range(0, len(st.session_state.messages), 2):
-        question = st.session_state.messages[id]['content']
-        if len(st.session_state.messages) > id + 1:
-            answer = st.session_state.messages[id + 1]['content']
-            st.session_state.memory.save_context({"input": question}, {"output": answer})
+    # if len(st.session_state.messages) > 1:
+    #     last_answer = st.session_state.messages[len(st.session_state.messages)-1]
+    #     if last_answer['role'] == "assistant":
+    #         last_question = st.session_state.messages[len(st.session_state.messages)-2]
+    #         st.session_state.memory.save_context({"input": last_question['content']}, {"output": last_answer['content']})
 
 elif st.session_state.loaded_embeddings and st.session_state.doc_id:
     play_old_messages()
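
On the Streamlit side, the memory now travels into the engine once at construction time instead of being passed with every query. A hedged sketch of the wiring, assuming the app keeps a ConversationBufferWindowMemory under st.session_state['memory'] (the key init_qa reads above); the window size k=4 is an illustrative guess, and clear_memory mirrors the reset hook named in the hunk header:

    import streamlit as st
    from langchain.memory import ConversationBufferWindowMemory

    # create the shared memory once per browser session
    if 'memory' not in st.session_state:
        st.session_state['memory'] = ConversationBufferWindowMemory(k=4)

    def clear_memory():
        # drop the conversational history without rebuilding the engine
        st.session_state['memory'].clear()

Because init_qa() hands this object to DocumentQAEngine, clearing it in session state empties the history that _get_context injects, so the next question starts from a clean slate; the replay loop that re-saved every message pair after each response is no longer needed.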