markpeace commited on
Commit
224ff63
1 Parent(s): 8b9c87b

basic chat is working again

Browse files
agent/_create.py CHANGED
@@ -1,27 +1,32 @@
1
 
2
  def agent(payload):
3
 
4
-
5
- from langchain_openai import ChatOpenAI
6
- llm = ChatOpenAI(model="gpt-3.5-turbo-0125", temperature=0)
 
 
 
 
7
 
 
 
 
8
  from agent.toolset import tool_executor, converted_tools
9
- model = llm.bind_functions(converted_tools)
10
 
11
  from langgraph.prebuilt import ToolInvocation
12
  import json
13
  from langchain_core.messages import FunctionMessage
14
 
15
- def should_continue(messages):
16
- last_message = messages[-1]
17
- if "function_call" not in last_message.additional_kwargs: return "end"
18
- else: return "continue"
19
 
20
- def call_model(messages):
21
- response = model.invoke(messages)
 
 
 
22
  return response
23
 
24
- def call_tool(messages):
25
  last_message = messages[-1]
26
  action = ToolInvocation(
27
  tool=last_message.additional_kwargs["function_call"]["name"],
@@ -29,37 +34,70 @@ def agent(payload):
29
  )
30
  response = tool_executor.invoke(action)
31
  function_message = FunctionMessage(content=str(response), name=action.tool)
 
 
 
32
  return function_message
33
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
  from langgraph.graph import MessageGraph, END
35
  workflow = MessageGraph()
36
 
37
- workflow.add_node("agent", call_model)
38
- workflow.add_node("action", call_tool)
39
- workflow.set_entry_point("agent")
40
- workflow.add_conditional_edges("agent", should_continue, {"continue": "action", "end": END} )
41
- workflow.add_edge('action', 'agent')
42
 
43
- from agent.memory import Memory
44
- memory = Memory(payload)
45
 
46
- app = workflow.compile(checkpointer=memory.checkpoints)
 
 
 
47
 
48
- from agent.prompt import prompt
49
- prompt=prompt[memory.isNew]
50
 
51
- input = payload.get("input") or "What is Rise for?"
 
 
 
 
 
 
 
 
 
52
 
53
- prompt = prompt.format(input=input, thread_id=memory.thread_id)
54
 
 
 
 
55
 
 
 
56
 
57
- response = app.invoke(prompt, {"configurable": {"thread_id": memory.thread_id}})
58
- print(response[-1].content)
59
 
60
- #for s in app.stream(prompt, {"configurable": {"thread_id": memory.thread_id}}):
61
- # print(list(s.values())[0])
62
- # print("----")
 
63
 
 
 
 
 
 
 
64
  return response[-1].content
65
 
 
1
 
2
  def agent(payload):
3
 
4
+ DEBUG=True
5
+
6
+ from agent.memory import Memory
7
+ memory = Memory(payload)
8
+
9
+ from agent.jsonencoder import json_parse_chain
10
+ from agent.agent_main import Chain_Main_Agent
11
 
12
+ chain_main_agent = Chain_Main_Agent(memory)
13
+
14
+
15
  from agent.toolset import tool_executor, converted_tools
 
16
 
17
  from langgraph.prebuilt import ToolInvocation
18
  import json
19
  from langchain_core.messages import FunctionMessage
20
 
 
 
 
 
21
 
22
+ def call_main_agent(messages):
23
+ response = chain_main_agent.invoke({"conversation":messages, "thread_id": memory.thread_id})
24
+
25
+ if DEBUG: print("call_main_agent called");
26
+
27
  return response
28
 
29
+ def use_tool(messages):
30
  last_message = messages[-1]
31
  action = ToolInvocation(
32
  tool=last_message.additional_kwargs["function_call"]["name"],
 
34
  )
35
  response = tool_executor.invoke(action)
36
  function_message = FunctionMessage(content=str(response), name=action.tool)
37
+
38
+ if DEBUG: print("Suggesting Tool to use..."+action.tool);
39
+
40
  return function_message
41
 
42
+ def render_output(messages):
43
+
44
+ import json
45
+
46
+ response = json_parse_chain.invoke({"conversation":messages, "thread_id": memory.thread_id})
47
+
48
+ if DEBUG: print("Rendering output");
49
+
50
+ from langchain_core.messages import AIMessage
51
+
52
+ response = json.dumps(response)
53
+ return AIMessage(content=response)
54
+
55
  from langgraph.graph import MessageGraph, END
56
  workflow = MessageGraph()
57
 
58
+ workflow.add_node("main_agent", call_main_agent)
59
+ workflow.add_node("use_tool", use_tool)
60
+ workflow.add_node("render_output", render_output)
 
 
61
 
62
+ workflow.set_entry_point("main_agent")
 
63
 
64
+ def should_continue(messages):
65
+ last_message = messages[-1]
66
+ if "function_call" not in last_message.additional_kwargs: return "render_output"
67
+ else: return "continue"
68
 
 
 
69
 
70
+ workflow.add_conditional_edges(
71
+ "main_agent", should_continue,
72
+ {
73
+ "continue": "use_tool",
74
+ "render_output":"render_output",
75
+ "end": END
76
+ }
77
+ )
78
+ workflow.add_edge('use_tool', 'main_agent')
79
+ workflow.add_edge('render_output', END)
80
 
 
81
 
82
+ app = workflow.compile(checkpointer=memory.checkpoints)
83
+
84
+ from langchain_core.messages import HumanMessage
85
 
86
+ input = payload.get("input") or "What is Rise for?"
87
+ inputs = [HumanMessage(content=input)]
88
 
89
+ response = app.invoke(inputs, {"configurable": {"thread_id": memory.thread_id} } )
 
90
 
91
+ '''
92
+ inputs = [HumanMessage(content="My name is Mark")]
93
+ response = app.invoke(inputs, {"configurable": {"thread_id": memory.thread_id} } )
94
+ print(response[-1].content)
95
 
96
+ inputs = [HumanMessage(content="What is my name?")]
97
+ response = app.invoke(inputs, {"configurable": {"thread_id": memory.thread_id} } )
98
+ print(response[-1].content)
99
+ '''
100
+
101
+ print(response[-1].content)
102
  return response[-1].content
103
 
agent/agent_main.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ def Chain_Main_Agent(memory):
3
+
4
+ from langchain_openai import ChatOpenAI
5
+
6
+ from agent.prompt import prompt
7
+ prompt=prompt[memory.isNew]
8
+
9
+ llm = ChatOpenAI(model="gpt-3.5-turbo-0125", temperature=0)
10
+ from agent.toolset import converted_tools
11
+ llm_with_tools = llm.bind_functions(converted_tools)
12
+
13
+ from agent.datastructures import parser
14
+
15
+ chain_main_agent = (
16
+ prompt
17
+ | llm_with_tools
18
+ )
19
+
20
+ return chain_main_agent
agent/datastructures.py CHANGED
@@ -3,6 +3,8 @@ from typing import List, Optional
3
  from enum import Enum
4
  from langchain_core.pydantic_v1 import BaseModel, Field
5
  from langchain.output_parsers import PydanticOutputParser
 
 
6
 
7
  ## DEFINE INPUT FRAMEWORK
8
  class InputSchema(BaseModel):
@@ -25,7 +27,7 @@ class ResponseSchema(BaseModel):
25
  """Always use this to format the final response to the user. This will be passed back to the frontend."""
26
  message: str = Field(description="final answer to respond to the user")
27
  thread_id: int = Field(description="The ID of the checkpointer memory thread that this response is associated with. This is used to keep track of the conversation.")
28
- tools: Optional[List[str]] = Field(description="A list of the tools used to generate the response.")
29
- actions: Optional[List[FrontEndActions]] = Field(description="List of suggested actions that should be passed back to the frontend to display. The use will click these to enact them. ")
30
 
31
- parser = PydanticOutputParser(pydantic_object=ResponseSchema)
 
3
  from enum import Enum
4
  from langchain_core.pydantic_v1 import BaseModel, Field
5
  from langchain.output_parsers import PydanticOutputParser
6
+ from langchain_core.output_parsers import JsonOutputParser
7
+
8
 
9
  ## DEFINE INPUT FRAMEWORK
10
  class InputSchema(BaseModel):
 
27
  """Always use this to format the final response to the user. This will be passed back to the frontend."""
28
  message: str = Field(description="final answer to respond to the user")
29
  thread_id: int = Field(description="The ID of the checkpointer memory thread that this response is associated with. This is used to keep track of the conversation.")
30
+ tools: List[str] = Field(description="A list of the tools used to generate the response.")
31
+ actions: List[FrontEndActions] = Field(description="List of suggested actions that should be passed back to the frontend to display. The use will click these to enact them. ")
32
 
33
+ parser = JsonOutputParser(pydantic_object=ResponseSchema)
agent/jsonencoder.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from langchain_openai import ChatOpenAI
2
+ from agent.datastructures import parser
3
+
4
+
5
+ model = ChatOpenAI(model="gpt-3.5-turbo-0125", temperature=0)
6
+
7
+ #from langchain_core.utils.function_calling import convert_pydantic_to_openai_function
8
+ #from agent.datastructures import ResponseSchema
9
+
10
+ #converted_tools = [convert_pydantic_to_openai_function(ResponseSchema)]
11
+
12
+ #model.bind_functions(convert_pydantic_to_openai_function(ResponseSchema))
13
+
14
+
15
+ from langchain.prompts import ChatPromptTemplate,PromptTemplate, MessagesPlaceholder,SystemMessagePromptTemplate
16
+
17
+ prompt = PromptTemplate(
18
+ template="""
19
+
20
+ {format_instructions}
21
+ Only provide a single JSON blob, beginning with '{{' and ending with '}}'
22
+ /n {input} /n
23
+
24
+ """,
25
+ input_variables=["input"],
26
+ partial_variables={"format_instructions": parser.get_format_instructions()},
27
+
28
+
29
+ )
30
+
31
+ prompt = ChatPromptTemplate.from_messages(
32
+ [
33
+ ("system", "The thread_id of this conversation is {thread_id}."),
34
+ ("system", "You will be given the chat so far, you should render the final answer as a JSON object"),
35
+ SystemMessagePromptTemplate.from_template("{format_instructions}").format(format_instructions=parser.get_format_instructions()),
36
+ MessagesPlaceholder(variable_name="conversation"),
37
+ ]
38
+ )
39
+
40
+
41
+ json_parse_chain = prompt | model | parser
agent/memory.py CHANGED
@@ -10,8 +10,6 @@ class extended_sqliteSaver(SqliteSaver):
10
  import sqlite3
11
  return SqliteSaver(conn=sqlite3.connect(conn_string, check_same_thread=False))
12
 
13
- #memory = extended_sqliteSaver.from_conn_string("_chat_history.sqlite")
14
-
15
  class Memory:
16
 
17
  def __init__(self, payload):
 
10
  import sqlite3
11
  return SqliteSaver(conn=sqlite3.connect(conn_string, check_same_thread=False))
12
 
 
 
13
  class Memory:
14
 
15
  def __init__(self, payload):
agent/prompt.py CHANGED
@@ -1,6 +1,7 @@
1
- from langchain_core.prompts import ChatPromptTemplate,SystemMessagePromptTemplate
2
  from agent.datastructures import parser
3
 
 
4
  prompt = {
5
  0: # IF THREAD IS CONTINUING, WE CAN RELY ON THE ORIGINAL PROMPT
6
  ChatPromptTemplate.from_messages([
@@ -21,23 +22,26 @@ prompt = {
21
 
22
  You have been provided with the FrequentlyAskedQuestions tool to answer questions that students might have about the Rise programme and Future me initiative. Please rely on this tool and do not make up answers if you are unsure.
23
 
24
- If a question seems relevant to Rise and Future me, but you are unsure of the answer, you are able to refer it to the Rise team using the EmailTeam tool. Before you do this, please confirm with the user.
25
-
26
- ###########
27
-
28
- ### PROVIDING A FINAL ANSWER ###
29
-
30
- {response_format}
31
-
32
- Never output anything outside of the JSON Blob beginning with '{{' and ending with '}}'
33
-
34
- ###########
35
-
36
-
37
  """
38
 
39
- ).format(response_format=parser.get_format_instructions()),
40
  ("system", "The thread_id of this conversation is {thread_id}."),
41
  ("human", "{input}")
42
  ])
43
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from langchain_core.prompts import ChatPromptTemplate,SystemMessagePromptTemplate,MessagesPlaceholder
2
  from agent.datastructures import parser
3
 
4
+ '''
5
  prompt = {
6
  0: # IF THREAD IS CONTINUING, WE CAN RELY ON THE ORIGINAL PROMPT
7
  ChatPromptTemplate.from_messages([
 
22
 
23
  You have been provided with the FrequentlyAskedQuestions tool to answer questions that students might have about the Rise programme and Future me initiative. Please rely on this tool and do not make up answers if you are unsure.
24
 
25
+ If a question seems relevant to Rise and Future me, but you are unsure of the answer, you are able to refer it to the Rise team using the EmailTeam tool. Before you do this, please confirm with the user.
26
+
 
 
 
 
 
 
 
 
 
 
 
27
  """
28
 
29
+ ),
30
  ("system", "The thread_id of this conversation is {thread_id}."),
31
  ("human", "{input}")
32
  ])
33
+ }
34
+ '''
35
+ prompt = {
36
+
37
+ 0: ChatPromptTemplate.from_messages([
38
+ MessagesPlaceholder(variable_name="conversation")
39
+ ]),
40
+
41
+ 1: ChatPromptTemplate.from_messages([
42
+ ("system", "The thread_id of this conversation is {thread_id}."),
43
+ ("system", "In your answer you should list the tools used to produce this answer"),
44
+ MessagesPlaceholder(variable_name="conversation")
45
+ ])
46
+
47
+ }
test.py CHANGED
@@ -1,62 +1,40 @@
 
1
  from dotenv import load_dotenv
2
  load_dotenv()
3
 
4
- from langchain_community.document_loaders.csv_loader import CSVLoader
5
- from langchain.text_splitter import RecursiveCharacterTextSplitter
6
- from langchain_openai import OpenAIEmbeddings, OpenAI, ChatOpenAI
7
- from langchain_community.vectorstores.faiss import FAISS
8
- from langchain_community.document_loaders import WebBaseLoader
9
- from langchain.agents import tool
10
- from langchain_openai import OpenAIEmbeddings
11
- from langchain_community.vectorstores.faiss import FAISS
12
- from langchain.chains import RetrievalQA
13
- from langchain_openai import OpenAI
14
- from langchain_core.pydantic_v1 import BaseModel, Field
15
- from langchain_core.prompts import PromptTemplate, ChatPromptTemplate, MessagesPlaceholder,SystemMessagePromptTemplate
16
- from agent.datastructures import parser
17
- from langchain.text_splitter import CharacterTextSplitter
18
-
19
-
20
- def train():
21
-
22
- documents = CSVLoader(file_path="train/posts.csv").load()
23
-
24
- # Split document in chunks
25
- text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=30)
26
- docs = text_splitter.split_documents(documents=documents)
27
-
28
- embeddings = OpenAIEmbeddings()
29
- # Create vectors
30
- vectorstore = FAISS.from_documents(docs, embeddings)
31
- # Persist the vectors locally on disk
32
- vectorstore.save_local("_rise_product_db");
33
-
34
- print("trained")
35
-
36
- def go():
37
- # Load from local storage
38
- embeddings = OpenAIEmbeddings()
39
- persisted_vectorstore = FAISS.load_local("_rise_product_db", embeddings)
40
-
41
- # Set Up LLM
42
- from agent.prompt import prompt
43
- llm = ChatOpenAI(model="gpt-4", temperature=0)
44
-
45
- prompt = ChatPromptTemplate.from_messages([
46
-
47
- SystemMessagePromptTemplate.from_template("""
48
-
49
- {response_format}
50
-
51
- {context}
52
-
53
- """,partial_variables={"response_format": parser.get_format_instructions()})
54
- ])
55
-
56
- # Use RetrievalQA chain for orchestration
57
- qa = RetrievalQA.from_chain_type(llm=llm, retriever=persisted_vectorstore.as_retriever(),chain_type_kwargs={"prompt": prompt})
58
- profile = "I would like to be a teacher, can you recommend an activity"
59
- result = qa.invoke("recommend activities relevant to the following profile. Activities cannot have already begun: "+profile)
60
- print(result)
61
-
62
- go();
 
1
+
2
  from dotenv import load_dotenv
3
  load_dotenv()
4
 
5
+ from langchain_core.output_parsers import StrOutputParser
6
+ from langchain_core.prompts import ChatPromptTemplate
7
+ from langchain_openai import ChatOpenAI
8
+
9
+ from typing import List
10
+
11
+ from langchain.prompts import PromptTemplate
12
+ from langchain_core.output_parsers import JsonOutputParser
13
+ from langchain_core.pydantic_v1 import BaseModel, Field
14
+ from langchain_openai import ChatOpenAI
15
+
16
+
17
+
18
+ class Joke(BaseModel):
19
+ setup: str = Field(description="question to set up a joke")
20
+ punchline: str = Field(description="answer to resolve the joke")
21
+
22
+ parser = JsonOutputParser(pydantic_object=Joke)
23
+
24
+ prompt = PromptTemplate(
25
+ template="Answer the user query.\n{format_instructions}\n{query}\n",
26
+ input_variables=["query"],
27
+ partial_variables={"format_instructions": parser.get_format_instructions()},
28
+ )
29
+
30
+ model = ChatOpenAI(model="gpt-4")
31
+ output_parser = StrOutputParser()
32
+ from langchain_core.output_parsers import JsonOutputParser
33
+
34
+
35
+ chain = prompt | model | parser
36
+
37
+ response=chain.invoke({"query": "ice cream"})
38
+
39
+ print(response)
40
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
train/faq.py CHANGED
@@ -6,7 +6,7 @@ def train():
6
  from langchain_community.document_loaders import WebBaseLoader
7
 
8
  documents = WebBaseLoader("https://rise.mmu.ac.uk/what-is-rise/").load()
9
- documents[0].page_content = documents[0].page_content.split("Everything You Need To Know About Rise – Students")[1].strip();
10
 
11
  # Split document in chunks
12
  text_splitter = RecursiveCharacterTextSplitter(
 
6
  from langchain_community.document_loaders import WebBaseLoader
7
 
8
  documents = WebBaseLoader("https://rise.mmu.ac.uk/what-is-rise/").load()
9
+ documents[0].page_content = documents[0].page_content.split("Student FAQ")[1].strip();
10
 
11
  # Split document in chunks
12
  text_splitter = RecursiveCharacterTextSplitter(