bstraehle committed on
Commit
5c45105
1 Parent(s): 6dc2383

Create langgraph

Browse files
Files changed (1) hide show
  1. langgraph +184 -0
langgraph ADDED
@@ -0,0 +1,184 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ from langchain_core.messages import (
4
+ BaseMessage,
5
+ ToolMessage,
6
+ HumanMessage,
7
+ )
8
+ from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
9
+ from langgraph.graph import END, StateGraph
10
+
11
+
12
def create_agent(llm, tools, system_message: str):
    """Build a runnable agent: a system prompt piped into *llm* with *tools* bound.

    The returned runnable expects a dict with a ``messages`` list and produces
    the model's reply, with tool-calling enabled for the given tools.
    """
    system_text = (
        "You are a helpful AI assistant, collaborating with other assistants."
        " Use the provided tools to progress towards answering the question."
        " If you are unable to fully answer, that's OK, another assistant with different tools "
        " will help where you left off. Execute what you can to make progress."
        " If you or any of the other assistants have the final answer or deliverable,"
        " prefix your response with FINAL ANSWER so the team knows to stop."
        " You have access to the following tools: {tool_names}.\n{system_message}"
    )
    prompt = ChatPromptTemplate.from_messages(
        [
            ("system", system_text),
            MessagesPlaceholder(variable_name="messages"),
        ]
    )
    # Pre-fill the template variables that are fixed for this agent.
    tool_names = ", ".join(tool.name for tool in tools)
    prompt = prompt.partial(system_message=system_message, tool_names=tool_names)
    return prompt | llm.bind_tools(tools)
32
+
33
from langchain_core.tools import tool
from typing import Annotated
from langchain_experimental.utilities import PythonREPL
from langchain_community.tools.tavily_search import TavilySearchResults

# Web-search tool used by the research agent (up to 5 results per query).
tavily_tool = TavilySearchResults(max_results=5)

# Warning: This executes code locally, which can be unsafe when not sandboxed

# Shared Python interpreter instance used by the python_repl tool below.
repl = PythonREPL()
44
@tool
def python_repl(
    code: Annotated[str, "The python code to execute to generate your chart."]
):
    """Use this to execute python code. If you want to see the output of a value,
    you should print it out with `print(...)`. This is visible to the user."""
    # Run the snippet; report any failure back to the calling agent as text
    # rather than raising, so the graph keeps running.
    try:
        output = repl.run(code)
    except BaseException as e:
        return f"Failed to execute. Error: {repr(e)}"
    summary = f"Successfully executed:\n```python\n{code}\n```\nStdout: {output}"
    return summary + "\n\nIf you have completed all tasks, respond with FINAL ANSWER."
58
+
59
+ import operator
60
+ from typing import Annotated, Sequence, TypedDict
61
+
62
+ from langchain_openai import ChatOpenAI
63
+ from typing_extensions import TypedDict
64
+
65
+
66
# This defines the object that is passed between each node
# in the graph. We will create different nodes for each agent and tool
class AgentState(TypedDict):
    # Conversation transcript. The operator.add annotation is the reducer:
    # each node's returned messages are appended rather than overwriting.
    messages: Annotated[Sequence[BaseMessage], operator.add]
    # Name of the agent node that produced the latest message; the tool
    # node's outgoing edge reads this to route back to the caller.
    sender: str
71
+
72
+ import functools
73
+ from langchain_core.messages import AIMessage
74
+
75
+
76
# Helper that adapts an agent into a graph node callable.
def agent_node(state, agent, name):
    """Invoke *agent* on the current state and fold its reply back into it."""
    output = agent.invoke(state)
    # Tool messages pass through untouched; anything else is re-wrapped as an
    # AIMessage tagged with this node's name so the transcript shows who spoke.
    if not isinstance(output, ToolMessage):
        output = AIMessage(**output.dict(exclude={"type", "name"}), name=name)
    # 'sender' lets the tool node route control back to whichever agent called it.
    return {"messages": [output], "sender": name}
90
+
91
# Single chat model shared by both agents.
llm = ChatOpenAI(model="gpt-4-1106-preview")

# Research agent and node
research_agent = create_agent(
    llm,
    [tavily_tool],
    system_message="You should provide accurate data for the chart_generator to use.",
)
# Bind the agent and its node name so the graph can call it as node(state).
research_node = functools.partial(agent_node, agent=research_agent, name="Researcher")

# chart_generator
chart_agent = create_agent(
    llm,
    [python_repl],
    system_message="Any charts you display will be visible by the user.",
)
chart_node = functools.partial(agent_node, agent=chart_agent, name="chart_generator")
108
+
109
from langgraph.prebuilt import ToolNode

# One shared node that executes whichever tool the last AI message requested.
tools = [tavily_tool, python_repl]
tool_node = ToolNode(tools)
113
+
114
+ # Either agent can decide to end
115
+ from typing import Literal
116
+
117
def router(state) -> Literal["call_tool", "__end__", "continue"]:
    """Decide where control flows after an agent node has spoken."""
    last_message = state["messages"][-1]
    if last_message.tool_calls:
        # The agent asked for a tool: hand off to the tool-executor node.
        return "call_tool"
    if "FINAL ANSWER" in last_message.content:
        # Some agent declared the work finished: stop the graph.
        return "__end__"
    # Otherwise pass the conversation on to the other agent.
    return "continue"
128
+
129
workflow = StateGraph(AgentState)

# Two collaborating agents plus one shared tool-execution node.
workflow.add_node("Researcher", research_node)
workflow.add_node("chart_generator", chart_node)
workflow.add_node("call_tool", tool_node)

# After each agent speaks, the router either hands off to the other agent,
# dispatches a tool call, or ends the run on FINAL ANSWER.
workflow.add_conditional_edges(
    "Researcher",
    router,
    {"continue": "chart_generator", "call_tool": "call_tool", "__end__": END},
)
workflow.add_conditional_edges(
    "chart_generator",
    router,
    {"continue": "Researcher", "call_tool": "call_tool", "__end__": END},
)

workflow.add_conditional_edges(
    "call_tool",
    # Each agent node updates the 'sender' field
    # the tool calling node does not, meaning
    # this edge will route back to the original agent
    # who invoked the tool
    lambda x: x["sender"],
    {
        "Researcher": "Researcher",
        "chart_generator": "chart_generator",
    },
)
# The Researcher always gets the first turn.
workflow.set_entry_point("Researcher")
graph = workflow.compile()
160
+
161
from IPython.display import Image, display

# Rendering the graph as a Mermaid PNG needs optional extra dependencies;
# if they are missing we simply skip the visualization.
try:
    display(Image(graph.get_graph(xray=True).draw_mermaid_png()))
except Exception:
    # Was a bare `except:`, which also swallows KeyboardInterrupt/SystemExit;
    # Exception is broad enough to cover the optional-dependency failures here.
    pass
168
+
169
# Kick off the graph with the user's request and stream intermediate updates.
events = graph.stream(
    {
        "messages": [
            HumanMessage(
                content="Fetch the UK's GDP over the past 5 years,"
                " then draw a line graph of it."
                " Once you code it up, finish."
            )
        ],
    },
    # Maximum number of steps to take in the graph
    {"recursion_limit": 150},
)
# Print each node's state update as it is produced.
for s in events:
    print(s)
    print("----")