basic chat is working again
- agent/_create.py +66 -28
- agent/agent_main.py +20 -0
- agent/datastructures.py +5 -3
- agent/jsonencoder.py +41 -0
- agent/memory.py +0 -2
- agent/prompt.py +20 -16
- test.py +37 -59
- train/faq.py +1 -1
agent/_create.py
CHANGED
@@ -1,27 +1,32 @@
 
 def agent(payload):
 
-…
-…
-…
+    DEBUG=True
+
+    from agent.memory import Memory
+    memory = Memory(payload)
+
+    from agent.jsonencoder import json_parse_chain
+    from agent.agent_main import Chain_Main_Agent
 
+    chain_main_agent = Chain_Main_Agent(memory)
+
+
     from agent.toolset import tool_executor, converted_tools
-    model = llm.bind_functions(converted_tools)
 
     from langgraph.prebuilt import ToolInvocation
     import json
     from langchain_core.messages import FunctionMessage
 
-    def should_continue(messages):
-        last_message = messages[-1]
-        if "function_call" not in last_message.additional_kwargs: return "end"
-        else: return "continue"
 
-    def …
-        response = …
+    def call_main_agent(messages):
+        response = chain_main_agent.invoke({"conversation":messages, "thread_id": memory.thread_id})
+
+        if DEBUG: print("call_main_agent called");
+
         return response
 
-    def …
+    def use_tool(messages):
         last_message = messages[-1]
         action = ToolInvocation(
             tool=last_message.additional_kwargs["function_call"]["name"],
@@ -29,37 +34,70 @@ def agent(payload):
         )
         response = tool_executor.invoke(action)
         function_message = FunctionMessage(content=str(response), name=action.tool)
+
+        if DEBUG: print("Suggesting Tool to use..."+action.tool);
+
         return function_message
 
+    def render_output(messages):
+
+        import json
+
+        response = json_parse_chain.invoke({"conversation":messages, "thread_id": memory.thread_id})
+
+        if DEBUG: print("Rendering output");
+
+        from langchain_core.messages import AIMessage
+
+        response = json.dumps(response)
+        return AIMessage(content=response)
+
     from langgraph.graph import MessageGraph, END
     workflow = MessageGraph()
 
-    workflow.add_node("…
-    workflow.add_node("…
-    workflow.…
-    workflow.add_conditional_edges("agent", should_continue, {"continue": "action", "end": END} )
-    workflow.add_edge('action', 'agent')
+    workflow.add_node("main_agent", call_main_agent)
+    workflow.add_node("use_tool", use_tool)
+    workflow.add_node("render_output", render_output)
 
-
-    memory = Memory(payload)
+    workflow.set_entry_point("main_agent")
 
-…
+    def should_continue(messages):
+        last_message = messages[-1]
+        if "function_call" not in last_message.additional_kwargs: return "render_output"
+        else: return "continue"
 
-    from agent.prompt import prompt
-    prompt=prompt[memory.isNew]
-
-…
-
-    prompt = prompt.format(input=input, thread_id=memory.thread_id)
-
-    response = app.invoke(…
-    print(response[-1].content)
-
-
-
+    workflow.add_conditional_edges(
+        "main_agent", should_continue,
+        {
+            "continue": "use_tool",
+            "render_output": "render_output",
+            "end": END
+        }
+    )
+    workflow.add_edge('use_tool', 'main_agent')
+    workflow.add_edge('render_output', END)
+
+    app = workflow.compile(checkpointer=memory.checkpoints)
+
+    from langchain_core.messages import HumanMessage
+
+    input = payload.get("input") or "What is Rise for?"
+    inputs = [HumanMessage(content=input)]
+
+    response = app.invoke(inputs, {"configurable": {"thread_id": memory.thread_id}})
+
+    '''
+    inputs = [HumanMessage(content="My name is Mark")]
+    response = app.invoke(inputs, {"configurable": {"thread_id": memory.thread_id}})
+    print(response[-1].content)
+
+    inputs = [HumanMessage(content="What is my name?")]
+    response = app.invoke(inputs, {"configurable": {"thread_id": memory.thread_id}})
+    print(response[-1].content)
+    '''
+
+    print(response[-1].content)
     return response[-1].content
 
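Note: the rewritten agent wires a three-node LangGraph loop: main_agent proposes either an answer or a function call, use_tool executes the call and feeds the result back to main_agent, and render_output converts the final answer to JSON before END. A minimal self-contained sketch of the same topology, with stub node bodies standing in for the real chain and tool calls (stub contents are illustrative, not from this commit):

    # Sketch of the graph topology above; stub nodes replace the real
    # chain and tool calls, so this runs without an API key.
    from langgraph.graph import MessageGraph, END
    from langchain_core.messages import AIMessage, HumanMessage

    def call_main_agent(messages):
        return AIMessage(content="stub answer")          # real node: chain_main_agent.invoke(...)

    def use_tool(messages):
        return AIMessage(content="stub tool result")     # real node: tool_executor.invoke(...)

    def render_output(messages):
        return AIMessage(content='{"message": "stub"}')  # real node: json_parse_chain.invoke(...)

    def should_continue(messages):
        last_message = messages[-1]
        if "function_call" not in last_message.additional_kwargs: return "render_output"
        else: return "continue"

    workflow = MessageGraph()
    workflow.add_node("main_agent", call_main_agent)
    workflow.add_node("use_tool", use_tool)
    workflow.add_node("render_output", render_output)
    workflow.set_entry_point("main_agent")
    workflow.add_conditional_edges("main_agent", should_continue,
        {"continue": "use_tool", "render_output": "render_output", "end": END})
    workflow.add_edge("use_tool", "main_agent")
    workflow.add_edge("render_output", END)

    app = workflow.compile()
    print(app.invoke([HumanMessage(content="hello")])[-1].content)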
agent/agent_main.py
ADDED
@@ -0,0 +1,20 @@
+
+def Chain_Main_Agent(memory):
+
+    from langchain_openai import ChatOpenAI
+
+    from agent.prompt import prompt
+    prompt=prompt[memory.isNew]
+
+    llm = ChatOpenAI(model="gpt-3.5-turbo-0125", temperature=0)
+    from agent.toolset import converted_tools
+    llm_with_tools = llm.bind_functions(converted_tools)
+
+    from agent.datastructures import parser
+
+    chain_main_agent = (
+        prompt
+        | llm_with_tools
+    )
+
+    return chain_main_agent
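Chain_Main_Agent returns a two-stage runnable: the prompt variant selected by memory.isNew piped into a function-bound model (the parser import is unused here, since JSON rendering now happens in agent/jsonencoder.py). A hypothetical call, assuming a memory object with isNew and thread_id as in agent/memory.py:

    from langchain_core.messages import HumanMessage

    chain = Chain_Main_Agent(memory)   # memory.isNew selects the prompt variant
    reply = chain.invoke({
        "conversation": [HumanMessage(content="What is Rise?")],
        "thread_id": memory.thread_id,
    })
    # reply is an AIMessage; if the model chose a tool, the request sits in
    # reply.additional_kwargs["function_call"]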
agent/datastructures.py
CHANGED
@@ -3,6 +3,8 @@ from typing import List, Optional
 from enum import Enum
 from langchain_core.pydantic_v1 import BaseModel, Field
 from langchain.output_parsers import PydanticOutputParser
+from langchain_core.output_parsers import JsonOutputParser
+
 
 ## DEFINE INPUT FRAMEWORK
 class InputSchema(BaseModel):
@@ -25,7 +27,7 @@ class ResponseSchema(BaseModel):
     """Always use this to format the final response to the user. This will be passed back to the frontend."""
     message: str = Field(description="final answer to respond to the user")
     thread_id: int = Field(description="The ID of the checkpointer memory thread that this response is associated with. This is used to keep track of the conversation.")
-    tools: …
-    actions: …
+    tools: List[str] = Field(description="A list of the tools used to generate the response.")
+    actions: List[FrontEndActions] = Field(description="List of suggested actions that should be passed back to the frontend to display. The user will click these to enact them.")
 
-parser = …
+parser = JsonOutputParser(pydantic_object=ResponseSchema)
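Switching from PydanticOutputParser to JsonOutputParser means parsed output comes back as a plain dict rather than a ResponseSchema instance. A self-contained sketch of the pattern, using a simplified, illustrative schema:

    from typing import List
    from langchain_core.output_parsers import JsonOutputParser
    from langchain_core.pydantic_v1 import BaseModel, Field

    class MiniResponse(BaseModel):          # illustrative stand-in for ResponseSchema
        message: str = Field(description="final answer")
        tools: List[str] = Field(description="tools used")

    mini_parser = JsonOutputParser(pydantic_object=MiniResponse)
    print(mini_parser.get_format_instructions())   # the schema text injected into prompts
    print(mini_parser.parse('{"message": "hi", "tools": ["FAQ"]}'))
    # -> {'message': 'hi', 'tools': ['FAQ']}  (a dict, not a MiniResponse)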
agent/jsonencoder.py
ADDED
@@ -0,0 +1,41 @@
+from langchain_openai import ChatOpenAI
+from agent.datastructures import parser
+
+
+model = ChatOpenAI(model="gpt-3.5-turbo-0125", temperature=0)
+
+#from langchain_core.utils.function_calling import convert_pydantic_to_openai_function
+#from agent.datastructures import ResponseSchema
+
+#converted_tools = [convert_pydantic_to_openai_function(ResponseSchema)]
+
+#model.bind_functions(convert_pydantic_to_openai_function(ResponseSchema))
+
+
+from langchain.prompts import ChatPromptTemplate,PromptTemplate, MessagesPlaceholder,SystemMessagePromptTemplate
+
+prompt = PromptTemplate(
+    template="""
+
+    {format_instructions}
+    Only provide a single JSON blob, beginning with '{{' and ending with '}}'
+    /n {input} /n
+
+    """,
+    input_variables=["input"],
+    partial_variables={"format_instructions": parser.get_format_instructions()},
+
+
+)
+
+prompt = ChatPromptTemplate.from_messages(
+    [
+        ("system", "The thread_id of this conversation is {thread_id}."),
+        ("system", "You will be given the chat so far, you should render the final answer as a JSON object"),
+        SystemMessagePromptTemplate.from_template("{format_instructions}").format(format_instructions=parser.get_format_instructions()),
+        MessagesPlaceholder(variable_name="conversation"),
+    ]
+)
+
+
+json_parse_chain = prompt | model | parser
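Note that the PromptTemplate assigned to prompt is immediately shadowed by the ChatPromptTemplate below it, so only the chat version feeds json_parse_chain (the /n literals in the dead template also look like they were meant to be \n). A hypothetical invocation, assuming OPENAI_API_KEY is set:

    from langchain_core.messages import AIMessage, HumanMessage

    result = json_parse_chain.invoke({
        "conversation": [
            HumanMessage(content="What is Rise?"),
            AIMessage(content="Rise is ..."),   # prior model answer, elided
        ],
        "thread_id": 42,
    })
    # result should be a dict shaped by ResponseSchema, e.g. with keys
    # "message", "thread_id", "tools" and "actions"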
agent/memory.py
CHANGED
@@ -10,8 +10,6 @@ class extended_sqliteSaver(SqliteSaver):
         import sqlite3
         return SqliteSaver(conn=sqlite3.connect(conn_string, check_same_thread=False))
 
-#memory = extended_sqliteSaver.from_conn_string("_chat_history.sqlite")
-
 class Memory:
 
     def __init__(self, payload):
agent/prompt.py
CHANGED
@@ -1,6 +1,7 @@
-from langchain_core.prompts import ChatPromptTemplate,SystemMessagePromptTemplate
+from langchain_core.prompts import ChatPromptTemplate,SystemMessagePromptTemplate,MessagesPlaceholder
 from agent.datastructures import parser
 
+'''
 prompt = {
     0: # IF THREAD IS CONTINUING, WE CAN RELY ON THE ORIGINAL PROMPT
     ChatPromptTemplate.from_messages([
@@ -21,23 +22,26 @@ prompt = {
 
         You have been provided with the FrequentlyAskedQuestions tool to answer questions that students might have about the Rise programme and Future me initiative. Please rely on this tool and do not make up answers if you are unsure.
 
-        If a question seems relevant to Rise and Future me, but you are unsure of the answer, you are able to refer it to the Rise team using the EmailTeam tool. Before you do this, please confirm with the user.
-
-        ###########
-
-        ### PROVIDING A FINAL ANSWER ###
-
-        {response_format}
-
-        Never output anything outside of the JSON Blob beginning with '{{' and ending with '}}'
-
-        ###########
-
-
+        If a question seems relevant to Rise and Future me, but you are unsure of the answer, you are able to refer it to the Rise team using the EmailTeam tool. Before you do this, please confirm with the user.
+
         """
 
-        )
+        ),
         ("system", "The thread_id of this conversation is {thread_id}."),
         ("human", "{input}")
     ])
-}
+}
+'''
+
+prompt = {
+
+    0: ChatPromptTemplate.from_messages([
+        MessagesPlaceholder(variable_name="conversation")
+    ]),
+
+    1: ChatPromptTemplate.from_messages([
+        ("system", "The thread_id of this conversation is {thread_id}."),
+        ("system", "In your answer you should list the tools used to produce this answer"),
+        MessagesPlaceholder(variable_name="conversation")
+    ])
+
+}
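The replacement prompt dict is keyed by memory.isNew: 0 replays the stored conversation unchanged for continuing threads, while 1 prepends the system messages for a fresh thread. A sketch of formatting variant 1 (placeholder values):

    from agent.prompt import prompt
    from langchain_core.messages import HumanMessage

    selected = prompt[1]   # as if memory.isNew == 1 (a brand-new thread)
    msgs = selected.format_messages(
        thread_id=123,
        conversation=[HumanMessage(content="hello")],
    )
    for m in msgs:
        print(m.type, ":", m.content)
    # system : The thread_id of this conversation is 123.
    # system : In your answer you should list the tools used to produce this answer
    # human : hello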
test.py
CHANGED
@@ -1,62 +1,40 @@
+
 from dotenv import load_dotenv
 load_dotenv()
 
-from …
-from …
-from langchain_openai import …
-
-from …
-
-from …
-from …
-from …
-from langchain_openai import …
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-# Set Up LLM
-from agent.prompt import prompt
-llm = ChatOpenAI(model="gpt-4", temperature=0)
-
-prompt = ChatPromptTemplate.from_messages([
-
-    SystemMessagePromptTemplate.from_template("""
-
-    {response_format}
-
-    {context}
-
-    """,partial_variables={"response_format": parser.get_format_instructions()})
-])
-
-# Use RetrievalQA chain for orchestration
-qa = RetrievalQA.from_chain_type(llm=llm, retriever=persisted_vectorstore.as_retriever(),chain_type_kwargs={"prompt": prompt})
-profile = "I would like to be a teacher, can you recommend an activity"
-result = qa.invoke("recommend activities relevant to the following profile. Activities cannot have already begun: "+profile)
-print(result)
-
-go();
+from langchain_core.output_parsers import StrOutputParser
+from langchain_core.prompts import ChatPromptTemplate
+from langchain_openai import ChatOpenAI
+
+from typing import List
+
+from langchain.prompts import PromptTemplate
+from langchain_core.output_parsers import JsonOutputParser
+from langchain_core.pydantic_v1 import BaseModel, Field
+from langchain_openai import ChatOpenAI
+
+
+
+class Joke(BaseModel):
+    setup: str = Field(description="question to set up a joke")
+    punchline: str = Field(description="answer to resolve the joke")
+
+parser = JsonOutputParser(pydantic_object=Joke)
+
+prompt = PromptTemplate(
+    template="Answer the user query.\n{format_instructions}\n{query}\n",
+    input_variables=["query"],
+    partial_variables={"format_instructions": parser.get_format_instructions()},
+)
+
+model = ChatOpenAI(model="gpt-4")
+output_parser = StrOutputParser()
+from langchain_core.output_parsers import JsonOutputParser
+
+
+chain = prompt | model | parser
+
+response=chain.invoke({"query": "ice cream"})
+
+print(response)
+
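The rewritten test.py is now a self-contained smoke test of the JsonOutputParser pipeline (StrOutputParser and several other imports are left unused). If the model follows the format instructions, a shape check like this could be appended at the end of the script:

    # response is a plain dict matching the Joke schema,
    # e.g. {'setup': '...', 'punchline': '...'}
    assert set(response) >= {"setup", "punchline"}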
train/faq.py
CHANGED
@@ -6,7 +6,7 @@ def train():
     from langchain_community.document_loaders import WebBaseLoader
 
     documents = WebBaseLoader("https://rise.mmu.ac.uk/what-is-rise/").load()
-    documents[0].page_content = documents[0].page_content.split("…
+    documents[0].page_content = documents[0].page_content.split("Student FAQ")[1].strip();
 
     # Split document in chunks
     text_splitter = RecursiveCharacterTextSplitter(
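The completed split keeps only the page text after the "Student FAQ" heading before chunking. On a toy string (illustrative):

    page = "nav and header text Student FAQ Q: What is Rise? A: ..."
    content = page.split("Student FAQ")[1].strip()
    print(content)   # Q: What is Rise? A: ...
    # Note: this raises IndexError if "Student FAQ" is absent, so a guarded
    # variant would first check: if "Student FAQ" in page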