captain-awesome commited on
Commit
e1b701a
1 Parent(s): ba7b4e3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +35 -6
app.py CHANGED
@@ -15,6 +15,7 @@ from transformers import AutoModelForCausalLM, AutoTokenizer
15
  import os
16
  import transformers
17
  import torch
 
18
  # from dotenv import load_dotenv
19
 
20
  # load_dotenv()
@@ -61,21 +62,49 @@ def get_context_retriever_chain(vector_store,llm):
61
  return retriever_chain
62
 
63
 
64
- def get_conversational_rag_chain(retriever_chain,llm):
65
 
66
- llm=llm
67
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68
  template = "Answer the user's questions based on the below context:\n\n{context}"
69
  human_template = "{input}"
70
-
71
  prompt = ChatPromptTemplate.from_messages([
72
  ("system", template),
73
  MessagesPlaceholder(variable_name="chat_history"),
74
  ("user", human_template),
75
  ])
76
-
77
- stuff_documents_chain = create_stuff_documents_chain(llm,prompt)
78
-
 
 
 
 
 
 
 
79
  return create_retrieval_chain(retriever_chain, stuff_documents_chain)
80
 
81
  def get_response(user_input):
 
15
  import os
16
  import transformers
17
  import torch
18
+ from langchain_retrieval import BaseRetrieverChain
19
  # from dotenv import load_dotenv
20
 
21
  # load_dotenv()
 
62
  return retriever_chain
63
 
64
 
65
+ # def get_conversational_rag_chain(retriever_chain,llm):
66
 
67
+ # llm=llm
68
 
69
+ # template = "Answer the user's questions based on the below context:\n\n{context}"
70
+ # human_template = "{input}"
71
+
72
+ # prompt = ChatPromptTemplate.from_messages([
73
+ # ("system", template),
74
+ # MessagesPlaceholder(variable_name="chat_history"),
75
+ # ("user", human_template),
76
+ # ])
77
+
78
+ # stuff_documents_chain = create_stuff_documents_chain(llm,prompt)
79
+
80
+ # return create_retrieval_chain(retriever_chain, stuff_documents_chain)
81
+ def get_conversational_rag_chain(
82
+ retriever_chain: Optional[langchain_retrieval.BaseRetrieverChain],
83
+ llm: Callable[[str], str],
84
+ chat_history: Optional[langchain_core.prompts.chat.ChatPromptValue] = None,
85
+ ) -> langchain_retrieval.BaseRetrieverChain:
86
+
87
+ if not retriever_chain:
88
+ raise ValueError("`retriever_chain` cannot be None or an empty object.")
89
+
90
  template = "Answer the user's questions based on the below context:\n\n{context}"
91
  human_template = "{input}"
92
+
93
  prompt = ChatPromptTemplate.from_messages([
94
  ("system", template),
95
  MessagesPlaceholder(variable_name="chat_history"),
96
  ("user", human_template),
97
  ])
98
+
99
+ def safe_llm(input_str: str) -> str:
100
+ if isinstance(input_str, langchain_core.prompts.chat.ChatPromptValue):
101
+ input_str = str(input_str)
102
+
103
+ # Call the original llm, which should now work correctly
104
+ return llm(input_str)
105
+
106
+ stuff_documents_chain = create_stuff_documents_chain(safe_llm, prompt)
107
+
108
  return create_retrieval_chain(retriever_chain, stuff_documents_chain)
109
 
110
  def get_response(user_input):