bstraehle committed
Commit bd4c825
1 parent: f95e45f

Update rag_langgraph.py

Files changed (1):
  1. rag_langgraph.py (+110 -131)
rag_langgraph.py CHANGED
@@ -1,163 +1,142 @@
- import functools, operator
- from IPython.display import Image, display
- from langchain_core.messages import (
-     AIMessage,
-     BaseMessage,
-     ToolMessage,
-     HumanMessage,
- )
- from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
- from langchain_core.tools import tool
  from langchain_community.tools.tavily_search import TavilySearchResults
- from langchain_experimental.utilities import PythonREPL
  from langchain_openai import ChatOpenAI
- from langgraph.graph import END, StateGraph
- from langgraph.prebuilt import ToolNode
- from typing import Annotated, Literal, Sequence, TypedDict
- from typing_extensions import TypedDict

- def create_agent(llm, tools, system_message: str):
-     """Create an agent."""
      prompt = ChatPromptTemplate.from_messages(
          [
              (
                  "system",
-                 "You are a helpful AI assistant, collaborating with other assistants."
-                 " Use the provided tools to progress towards answering the question."
-                 " If you are unable to fully answer, that's OK, another assistant with different tools "
-                 " will help where you left off. Execute what you can to make progress."
-                 " If you or any of the other assistants have the final answer or deliverable,"
-                 " prefix your response with FINAL ANSWER so the team knows to stop."
-                 " You have access to the following tools: {tool_names}.\n{system_message}",
              ),
              MessagesPlaceholder(variable_name="messages"),
          ]
      )
-     prompt = prompt.partial(system_message=system_message)
-     prompt = prompt.partial(tool_names=", ".join([tool.name for tool in tools]))
-     return prompt | llm.bind_tools(tools)

- @tool
- def python_repl(code: Annotated[str, "The python code to execute to generate your chart."]):
-     """Use this to execute python code. If you want to see the output of a value,
-     you should print it out with `print(...)`. This is visible to the user."""
-     try:
-         result = repl.run(code)
-     except BaseException as e:
-         return f"Failed to execute. Error: {repr(e)}"
-     result_str = f"Successfully executed:\n```python\n{code}\n```\nStdout: {result}"
-     return (
-         result_str + "\n\nIf you have completed all tasks, respond with FINAL ANSWER."
-     )
-
- # This defines the object that is passed between each node
- # in the graph. We will create different nodes for each agent and tool
- class AgentState(TypedDict):
-     messages: Annotated[Sequence[BaseMessage], operator.add]
-     sender: str
-
- # Helper function to create a node for a given agent
  def agent_node(state, agent, name):
      result = agent.invoke(state)
-     # We convert the agent output into a format that is suitable to append to the global state
-     if isinstance(result, ToolMessage):
-         pass
-     else:
-         result = AIMessage(**result.dict(exclude={"type", "name"}), name=name)
-     return {
-         "messages": [result],
-         # Since we have a strict workflow, we can
-         # track the sender so we know who to pass to next.
-         "sender": name,
-     }
-
- def router(state) -> Literal["call_tool", "__end__", "continue"]:
-     # This is the router
-     messages = state["messages"]
-     last_message = messages[-1]
-     if last_message.tool_calls:
-         # The previous agent is invoking a tool
-         return "call_tool"
-     if "FINAL ANSWER" in last_message.content:
-         # Any agent decided the work is done
-         return "__end__"
-     return "continue"

  def run_multi_agent(prompt):
      tavily_tool = TavilySearchResults(max_results=5)
-     repl = PythonREPL()

-     llm = ChatOpenAI(model="gpt-4o")
-
-     # Research agent and node
-     research_agent = create_agent(
-         llm,
-         [tavily_tool],
-         system_message="You should provide accurate data for the chart_generator to use.",
      )
      research_node = functools.partial(agent_node, agent=research_agent, name="Researcher")

-     # chart_generator
-     chart_agent = create_agent(
          llm,
-         [python_repl],
-         system_message="Any charts you display will be visible by the user.",
      )
-     chart_node = functools.partial(agent_node, agent=chart_agent, name="chart_generator")
-
-     tools = [tavily_tool, python_repl]
-     tool_node = ToolNode(tools)
-
-     workflow = StateGraph(AgentState)
      workflow.add_node("Researcher", research_node)
-     workflow.add_node("chart_generator", chart_node)
-     workflow.add_node("call_tool", tool_node)
-
-     workflow.add_conditional_edges(
-         "Researcher",
-         router,
-         {"continue": "chart_generator", "call_tool": "call_tool", "__end__": END},
-     )
-     workflow.add_conditional_edges(
-         "chart_generator",
-         router,
-         {"continue": "Researcher", "call_tool": "call_tool", "__end__": END},
-     )

-     workflow.add_conditional_edges(
-         "call_tool",
-         # Each agent node updates the 'sender' field
-         # the tool calling node does not, meaning
-         # this edge will route back to the original agent
-         # who invoked the tool
-         lambda x: x["sender"],
-         {
-             "Researcher": "Researcher",
-             "chart_generator": "chart_generator",
-         },
-     )
-     workflow.set_entry_point("Researcher")
      graph = workflow.compile()

-     try:
-         display(Image(graph.get_graph(xray=True).draw_mermaid_png()))
-     except:
-         # This requires some extra dependencies and is optional
-         pass
-
-     events = graph.stream(
          {
              "messages": [
-                 HumanMessage(
-                     content=prompt
-                 )
-             ],
-         },
-         # Maximum number of steps to take in the graph
-         {"recursion_limit": 150},
-     )
-     for s in events:
-         print(s)
-         print("----")
      return "DONE"

+ import functools, operator
+ from typing import Annotated, Sequence, TypedDict
+
  from langchain_community.tools.tavily_search import TavilySearchResults
+ from langchain_experimental.tools import PythonREPLTool
+
+ from langchain.agents import AgentExecutor, create_openai_tools_agent
+ from langchain_core.messages import BaseMessage, HumanMessage
  from langchain_openai import ChatOpenAI
+ from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
+ from langchain_core.output_parsers.openai_functions import JsonOutputFunctionsParser
+ from langgraph.graph import StateGraph, END
+
+ def create_agent(llm: ChatOpenAI, tools: list, system_prompt: str):
+     # Each worker node will be given a name and some tools.
      prompt = ChatPromptTemplate.from_messages(
          [
              (
                  "system",
+                 system_prompt,
              ),
              MessagesPlaceholder(variable_name="messages"),
+             MessagesPlaceholder(variable_name="agent_scratchpad"),
          ]
      )
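+     # create_openai_tools_agent requires the "agent_scratchpad" placeholder
+     # above; AgentExecutor fills it with intermediate tool calls and results.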
+     agent = create_openai_tools_agent(llm, tools, prompt)
+     executor = AgentExecutor(agent=agent, tools=tools)
+     return executor

  def agent_node(state, agent, name):
      result = agent.invoke(state)
+     return {"messages": [HumanMessage(content=result["output"], name=name)]}

+ # The agent state is the input to each node in the graph
+ class AgentState(TypedDict):
+     # The annotation tells the graph that new messages will always
+     # be added to the current state
+     messages: Annotated[Sequence[BaseMessage], operator.add]
+     # The 'next' field indicates where to route to next
+     next: str
+
  def run_multi_agent(prompt):
      tavily_tool = TavilySearchResults(max_results=5)
+     python_repl_tool = PythonREPLTool()  # tool instance used by the Coder agent below

+     members = ["Researcher", "Coder"]
+     system_prompt = (
+         "You are a supervisor tasked with managing a conversation between the"
+         " following workers: {members}. Given the following user request,"
+         " respond with the worker to act next. Each worker will perform a"
+         " task and respond with their results and status. When finished,"
+         " respond with FINISH."
      )
+     # Our team supervisor is an LLM node. It just picks the next agent to process
+     # and decides when the work is completed
+     options = ["FINISH"] + members
+     # Using OpenAI function calling can make output parsing easier for us
+     function_def = {
+         "name": "route",
+         "description": "Select the next role.",
+         "parameters": {
+             "title": "routeSchema",
+             "type": "object",
+             "properties": {
+                 "next": {
+                     "title": "Next",
+                     "anyOf": [
+                         {"enum": options},
+                     ],
+                 }
+             },
+             "required": ["next"],
+         },
+     }
+     # Named distinctly so it does not shadow the caller's `prompt` argument
+     supervisor_prompt = ChatPromptTemplate.from_messages(
+         [
+             ("system", system_prompt),
+             MessagesPlaceholder(variable_name="messages"),
+             (
+                 "system",
+                 "Given the conversation above, who should act next?"
+                 " Or should we FINISH? Select one of: {options}",
+             ),
+         ]
+     ).partial(options=str(options), members=", ".join(members))
+
+     llm = ChatOpenAI(model="gpt-4-1106-preview")
+
+     supervisor_chain = (
+         supervisor_prompt
+         | llm.bind_functions(functions=[function_def], function_call="route")
+         | JsonOutputFunctionsParser()
+     )
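+     # JsonOutputFunctionsParser turns the forced "route" function call into a
+     # plain dict such as {"next": "Researcher"}; the graph merges that dict
+     # into AgentState, so the conditional edge below can read state["next"].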
+
+     research_agent = create_agent(llm, [tavily_tool], "You are a web researcher.")
      research_node = functools.partial(agent_node, agent=research_agent, name="Researcher")

+     # NOTE: THIS PERFORMS ARBITRARY CODE EXECUTION. PROCEED WITH CAUTION
+     code_agent = create_agent(
          llm,
+         [python_repl_tool],
+         "You may generate safe python code to analyze data and generate charts using matplotlib.",
      )
+     code_node = functools.partial(agent_node, agent=code_agent, name="Coder")

+     workflow = StateGraph(AgentState)
      workflow.add_node("Researcher", research_node)
+     workflow.add_node("Coder", code_node)
+     workflow.add_node("supervisor", supervisor_chain)
+
+     for member in members:
+         # We want our workers to ALWAYS "report back" to the supervisor when done
+         workflow.add_edge(member, "supervisor")
+     # The supervisor populates the "next" field in the graph state,
+     # which routes to a node or finishes
+     conditional_map = {k: k for k in members}
+     conditional_map["FINISH"] = END
+     workflow.add_conditional_edges("supervisor", lambda x: x["next"], conditional_map)
+     # Finally, add entrypoint
+     workflow.set_entry_point("supervisor")

      graph = workflow.compile()

+     for s in graph.stream(
          {
              "messages": [
+                 HumanMessage(content=prompt)  # route the caller's request into the graph
+             ]
+         }
+     ):
+         if "__end__" not in s:
+             print(s)
+             print("----")
+
  return "DONE"
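
A minimal usage sketch of the updated entry point (assumptions: OPENAI_API_KEY and TAVILY_API_KEY are set in the environment, the module is importable as rag_langgraph, and the example prompt is illustrative rather than taken from this commit):

    from rag_langgraph import run_multi_agent

    # Streams each graph step (supervisor routing decisions and worker
    # outputs) to stdout, then returns the literal string "DONE".
    result = run_multi_agent(
        "Fetch the UK's GDP over the past 5 years, then draw a line graph of it."
    )
    print(result)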