captain-awesome
commited on
Update app.py
Browse files
app.py
CHANGED
@@ -15,6 +15,7 @@ from transformers import AutoModelForCausalLM, AutoTokenizer
|
|
15 |
import os
|
16 |
import transformers
|
17 |
import torch
|
|
|
18 |
# from dotenv import load_dotenv
|
19 |
|
20 |
# load_dotenv()
|
@@ -61,21 +62,49 @@ def get_context_retriever_chain(vector_store,llm):
|
|
61 |
return retriever_chain
|
62 |
|
63 |
|
64 |
-
def get_conversational_rag_chain(retriever_chain,llm):
|
65 |
|
66 |
-
|
67 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
68 |
template = "Answer the user's questions based on the below context:\n\n{context}"
|
69 |
human_template = "{input}"
|
70 |
-
|
71 |
prompt = ChatPromptTemplate.from_messages([
|
72 |
("system", template),
|
73 |
MessagesPlaceholder(variable_name="chat_history"),
|
74 |
("user", human_template),
|
75 |
])
|
76 |
-
|
77 |
-
|
78 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
79 |
return create_retrieval_chain(retriever_chain, stuff_documents_chain)
|
80 |
|
81 |
def get_response(user_input):
|
|
|
15 |
import os
|
16 |
import transformers
|
17 |
import torch
|
18 |
+
from langchain_retrieval import BaseRetrieverChain
|
19 |
# from dotenv import load_dotenv
|
20 |
|
21 |
# load_dotenv()
|
|
|
62 |
return retriever_chain
|
63 |
|
64 |
|
65 |
+
# def get_conversational_rag_chain(retriever_chain,llm):
|
66 |
|
67 |
+
# llm=llm
|
68 |
|
69 |
+
# template = "Answer the user's questions based on the below context:\n\n{context}"
|
70 |
+
# human_template = "{input}"
|
71 |
+
|
72 |
+
# prompt = ChatPromptTemplate.from_messages([
|
73 |
+
# ("system", template),
|
74 |
+
# MessagesPlaceholder(variable_name="chat_history"),
|
75 |
+
# ("user", human_template),
|
76 |
+
# ])
|
77 |
+
|
78 |
+
# stuff_documents_chain = create_stuff_documents_chain(llm,prompt)
|
79 |
+
|
80 |
+
# return create_retrieval_chain(retriever_chain, stuff_documents_chain)
|
81 |
+
def get_conversational_rag_chain(
|
82 |
+
retriever_chain: Optional[langchain_retrieval.BaseRetrieverChain],
|
83 |
+
llm: Callable[[str], str],
|
84 |
+
chat_history: Optional[langchain_core.prompts.chat.ChatPromptValue] = None,
|
85 |
+
) -> langchain_retrieval.BaseRetrieverChain:
|
86 |
+
|
87 |
+
if not retriever_chain:
|
88 |
+
raise ValueError("`retriever_chain` cannot be None or an empty object.")
|
89 |
+
|
90 |
template = "Answer the user's questions based on the below context:\n\n{context}"
|
91 |
human_template = "{input}"
|
92 |
+
|
93 |
prompt = ChatPromptTemplate.from_messages([
|
94 |
("system", template),
|
95 |
MessagesPlaceholder(variable_name="chat_history"),
|
96 |
("user", human_template),
|
97 |
])
|
98 |
+
|
99 |
+
def safe_llm(input_str: str) -> str:
|
100 |
+
if isinstance(input_str, langchain_core.prompts.chat.ChatPromptValue):
|
101 |
+
input_str = str(input_str)
|
102 |
+
|
103 |
+
# Call the original llm, which should now work correctly
|
104 |
+
return llm(input_str)
|
105 |
+
|
106 |
+
stuff_documents_chain = create_stuff_documents_chain(safe_llm, prompt)
|
107 |
+
|
108 |
return create_retrieval_chain(retriever_chain, stuff_documents_chain)
|
109 |
|
110 |
def get_response(user_input):
|