nickmuchi commited on
Commit
d748b7a
·
1 Parent(s): 525e646

Update functions.py

Browse files
Files changed (1) hide show
  1. functions.py +5 -1
functions.py CHANGED
@@ -32,6 +32,7 @@ from langchain.chat_models import ChatOpenAI
32
  from langchain.callbacks.base import CallbackManager
33
  from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
34
  from langchain.chains import ConversationalRetrievalChain, QAGenerationChain
 
35
 
36
  from langchain.prompts.chat import (
37
  ChatPromptTemplate,
@@ -57,6 +58,8 @@ time_str = time.strftime("%d%m%Y-%H%M%S")
57
  HTML_WRAPPER = """<div style="overflow-x: auto; border: 1px solid #e6e9ef; border-radius: 0.25rem; padding: 1rem;
58
  margin-bottom: 2.5rem">{}</div> """
59
 
 
 
60
  #Stuff Chain Type Prompt template
61
 
62
  @st.cache_resource
@@ -230,9 +233,10 @@ def embed_text(query,embedding_model,_docsearch):
230
  chain = ConversationalRetrievalChain.from_llm(chat_llm,
231
  retriever= _docsearch.as_retriever(),
232
  qa_prompt = load_prompt(),
 
233
  return_source_documents=True)
234
 
235
- answer = chain({"question": query, "chat_history": chat_history})
236
 
237
  return answer
238
 
 
32
  from langchain.callbacks.base import CallbackManager
33
  from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
34
  from langchain.chains import ConversationalRetrievalChain, QAGenerationChain
35
+ from langchain.memory import ConversationBufferMemory
36
 
37
  from langchain.prompts.chat import (
38
  ChatPromptTemplate,
 
58
  HTML_WRAPPER = """<div style="overflow-x: auto; border: 1px solid #e6e9ef; border-radius: 0.25rem; padding: 1rem;
59
  margin-bottom: 2.5rem">{}</div> """
60
 
61
+ memory = ConversationBufferMemory(memory_key='chat_history', return_messages=True)
62
+
63
  #Stuff Chain Type Prompt template
64
 
65
  @st.cache_resource
 
233
  chain = ConversationalRetrievalChain.from_llm(chat_llm,
234
  retriever= _docsearch.as_retriever(),
235
  qa_prompt = load_prompt(),
236
+ memory = memory,
237
  return_source_documents=True)
238
 
239
+ answer = chain({"question": query})
240
 
241
  return answer
242