bstraehle committed on
Commit bbaff74
Parent: 4b76b70

Update app.py

Files changed (1)
  app.py +158 -123
app.py CHANGED
@@ -14,172 +14,207 @@ _set_if_undefined("TAVILY_API_KEY")
 os.environ["LANGCHAIN_TRACING_V2"] = "true"
 os.environ["LANGCHAIN_PROJECT"] = "Multi-agent Collaboration"
 
-from typing import Annotated, List, Tuple, Union
-
-from langchain_community.tools.tavily_search import TavilySearchResults
-from langchain_core.tools import tool
-from langchain_experimental.tools import PythonREPLTool
-
-tavily_tool = TavilySearchResults(max_results=5)
-
-# This executes code locally, which can be unsafe
-python_repl_tool = PythonREPLTool()
 
-from langchain.agents import AgentExecutor, create_openai_tools_agent
-from langchain_core.messages import BaseMessage, HumanMessage
-from langchain_openai import ChatOpenAI
 
-def create_agent(llm: ChatOpenAI, tools: list, system_prompt: str):
-    # Each worker node will be given a name and some tools.
     prompt = ChatPromptTemplate.from_messages(
         [
             (
                 "system",
-                system_prompt,
             ),
             MessagesPlaceholder(variable_name="messages"),
-            MessagesPlaceholder(variable_name="agent_scratchpad"),
         ]
     )
-    agent = create_openai_tools_agent(llm, tools, prompt)
-    executor = AgentExecutor(agent=agent, tools=tools)
-    return executor
-
-def agent_node(state, agent, name):
-    result = agent.invoke(state)
-    return {"messages": [HumanMessage(content=result["output"], name=name)]}
 
-from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
-from langchain_core.output_parsers.openai_functions import JsonOutputFunctionsParser
-
-members = ["Researcher", "Coder"]
-system_prompt = (
-    "You are a supervisor tasked with managing a conversation between the"
-    " following workers: {members}. Given the following user request,"
-    " respond with the worker to act next. Each worker will perform a"
-    " task and respond with their results and status. When finished,"
-    " respond with FINISH."
-)
-# Our team supervisor is an LLM node. It just picks the next agent to process
-# and decides when the work is completed
-options = ["FINISH"] + members
-# Using openai function calling can make output parsing easier for us
-function_def = {
-    "name": "route",
-    "description": "Select the next role.",
-    "parameters": {
-        "title": "routeSchema",
-        "type": "object",
-        "properties": {
-            "next": {
-                "title": "Next",
-                "anyOf": [
-                    {"enum": options},
-                ],
-            }
-        },
-        "required": ["next"],
-    },
-}
-prompt = ChatPromptTemplate.from_messages(
-    [
-        ("system", system_prompt),
-        MessagesPlaceholder(variable_name="messages"),
-        (
-            "system",
-            "Given the conversation above, who should act next?"
-            " Or should we FINISH? Select one of: {options}",
-        ),
-    ]
-).partial(options=str(options), members=", ".join(members))
 
-llm = ChatOpenAI(model="gpt-4-1106-preview")
 
-supervisor_chain = (
-    prompt
-    | llm.bind_functions(functions=[function_def], function_call="route")
-    | JsonOutputFunctionsParser()
-)
 
 import operator
-from typing import Annotated, Any, Dict, List, Optional, Sequence, TypedDict
-import functools
 
-from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
-from langgraph.graph import StateGraph, END
 
-# The agent state is the input to each node in the graph
 class AgentState(TypedDict):
-    # The annotation tells the graph that new messages will always
-    # be added to the current state
     messages: Annotated[Sequence[BaseMessage], operator.add]
-    # The 'next' field indicates where to route to next
-    next: str
 
-research_agent = create_agent(llm, [tavily_tool], "You are a web researcher.")
 research_node = functools.partial(agent_node, agent=research_agent, name="Researcher")
 
-# NOTE: THIS PERFORMS ARBITRARY CODE EXECUTION. PROCEED WITH CAUTION
-code_agent = create_agent(
     llm,
-    [python_repl_tool],
-    "You may generate safe python code to analyze data and generate charts using matplotlib.",
 )
-code_node = functools.partial(agent_node, agent=code_agent, name="Coder")
 
 workflow = StateGraph(AgentState)
 workflow.add_node("Researcher", research_node)
-workflow.add_node("Coder", code_node)
-workflow.add_node("supervisor", supervisor_chain)
-
-for member in members:
-    # We want our workers to ALWAYS "report back" to the supervisor when done
-    workflow.add_edge(member, "supervisor")
-# The supervisor populates the "next" field in the graph state
-# which routes to a node or finishes
-conditional_map = {k: k for k in members}
-conditional_map["FINISH"] = END
-workflow.add_conditional_edges("supervisor", lambda x: x["next"], conditional_map)
-# Finally, add entrypoint
-workflow.set_entry_point("supervisor")
 
 graph = workflow.compile()
 
 ###
 
-def invoke(openai_api_key, topic, word_count=500):
     if (openai_api_key == ""):
         raise gr.Error("OpenAI API Key is required.")
-    if (topic == ""):
-        raise gr.Error("Topic is required.")
 
-    #agentops.init(os.environ["AGENTOPS_API_KEY"])
-
     os.environ["OPENAI_API_KEY"] = openai_api_key
-
-    for s in graph.stream(
-        {
-            "messages": [
-                HumanMessage(content="Code hello world and print it to the terminal")
-            ]
-        }
-    ):
-        if "__end__" not in s:
-            print(s)
-        print("----")
-
-    return result
 
 gr.close_all()
 
 demo = gr.Interface(fn = invoke,
                     inputs = [gr.Textbox(label = "OpenAI API Key", type = "password", lines = 1),
-                              gr.Textbox(label = "Topic", value="TODO", lines = 1),
-                              gr.Number(label = "Word Count", value=1000, minimum=500, maximum=5000)],
-                    outputs = [gr.Markdown(label = "Generated Blog Post", value="TODO")],
-                    title = "Multi-Agent RAG: Blog Post Generation",
                     description = "TODO")
 
 demo.launch()
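
For contrast with the rewrite below: the removed supervisor pattern forced the LLM to call a `route` function and parsed the arguments into a routing dict. Schematically, with a made-up input (a sketch, not part of the commit):

```python
decision = supervisor_chain.invoke(
    {"messages": [HumanMessage(content="Plot the UK's GDP.")]}
)
# JsonOutputFunctionsParser yields the function-call arguments, e.g.
# {"next": "Researcher"}, or {"next": "FINISH"} once the work is done.
```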
 
 os.environ["LANGCHAIN_TRACING_V2"] = "true"
 os.environ["LANGCHAIN_PROJECT"] = "Multi-agent Collaboration"
 
+from langchain_core.messages import (
+    BaseMessage,
+    ToolMessage,
+    HumanMessage,
+)
+from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
+from langgraph.graph import END, StateGraph
 
+def create_agent(llm, tools, system_message: str):
+    """Create an agent."""
     prompt = ChatPromptTemplate.from_messages(
         [
             (
                 "system",
+                "You are a helpful AI assistant, collaborating with other assistants."
+                " Use the provided tools to progress towards answering the question."
+                " If you are unable to fully answer, that's OK, another assistant with different tools"
+                " will help where you left off. Execute what you can to make progress."
+                " If you or any of the other assistants have the final answer or deliverable,"
+                " prefix your response with FINAL ANSWER so the team knows to stop."
+                " You have access to the following tools: {tool_names}.\n{system_message}",
             ),
             MessagesPlaceholder(variable_name="messages"),
         ]
     )
+    prompt = prompt.partial(system_message=system_message)
+    prompt = prompt.partial(tool_names=", ".join([tool.name for tool in tools]))
+    return prompt | llm.bind_tools(tools)
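
The new create_agent returns a plain runnable (prompt piped into llm.bind_tools) rather than an AgentExecutor. A minimal usage sketch, reusing the llm and tavily_tool defined later in this diff (the question is illustrative):

```python
agent = create_agent(
    llm,
    [tavily_tool],
    system_message="You are a web researcher.",
)
# The prompt's only remaining variable is "messages"; the result is an
# AIMessage that either carries tool_calls or plain text content.
reply = agent.invoke({"messages": [HumanMessage(content="What is the UK's GDP?")]})
print(reply.tool_calls or reply.content)
```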
 
 
 
 
 
+from langchain_core.tools import tool
+from typing import Annotated
+from langchain_experimental.utilities import PythonREPL
+from langchain_community.tools.tavily_search import TavilySearchResults
 
+tavily_tool = TavilySearchResults(max_results=5)
 
+# Warning: This executes code locally, which can be unsafe when not sandboxed
+
+repl = PythonREPL()
+
+@tool
+def python_repl(
+    code: Annotated[str, "The python code to execute to generate your chart."]
+):
+    """Use this to execute python code. If you want to see the output of a value,
+    you should print it out with `print(...)`. This is visible to the user."""
+    try:
+        result = repl.run(code)
+    except BaseException as e:
+        return f"Failed to execute. Error: {repr(e)}"
+    result_str = f"Successfully executed:\n```python\n{code}\n```\nStdout: {result}"
+    return (
+        result_str + "\n\nIf you have completed all tasks, respond with FINAL ANSWER."
+    )
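
Because python_repl is @tool-decorated, it can be smoke-tested on its own; LangChain tools accept a dict of their arguments via invoke. A sketch (not part of the commit, output abridged):

```python
print(python_repl.invoke({"code": "print(2 + 2)"}))
# Expected shape: "Successfully executed: ... Stdout: 4 ..." followed by the
# FINAL ANSWER reminder, or "Failed to execute. Error: ..." on failure.
```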
 
 import operator
+from typing import Annotated, Sequence, TypedDict
 
+from langchain_openai import ChatOpenAI
+from typing_extensions import TypedDict
 
+# This defines the object that is passed between each node
+# in the graph. We will create different nodes for each agent and tool
 class AgentState(TypedDict):
     messages: Annotated[Sequence[BaseMessage], operator.add]
+    sender: str
 
+import functools
+from langchain_core.messages import AIMessage
+
+# Helper function to create a node for a given agent
+def agent_node(state, agent, name):
+    result = agent.invoke(state)
+    # We convert the agent output into a format that is suitable to append to the global state
+    if isinstance(result, ToolMessage):
+        pass
+    else:
+        result = AIMessage(**result.dict(exclude={"type", "name"}), name=name)
+    return {
+        "messages": [result],
+        # Since we have a strict workflow, we can
+        # track the sender so we know who to pass to next.
+        "sender": name,
+    }
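
The operator.add annotation is the reducer LangGraph applies when merging a node's return value into the shared state: each agent_node result is appended to messages, while sender is simply overwritten. A hand-rolled illustration of that merge (values made up):

```python
state = {"messages": [HumanMessage(content="Fetch the data.")], "sender": ""}
update = {"messages": [AIMessage(content="On it.", name="Researcher")], "sender": "Researcher"}
# messages uses the operator.add reducer; sender is last-write-wins.
merged = {
    "messages": state["messages"] + update["messages"],
    "sender": update["sender"],
}
```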
 
+llm = ChatOpenAI(model="gpt-4-1106-preview")
+
+# Research agent and node
+research_agent = create_agent(
+    llm,
+    [tavily_tool],
+    system_message="You should provide accurate data for the chart_generator to use.",
+)
 research_node = functools.partial(agent_node, agent=research_agent, name="Researcher")
 
+# chart_generator
+chart_agent = create_agent(
     llm,
+    [python_repl],
+    system_message="Any charts you display will be visible by the user.",
 )
+chart_node = functools.partial(agent_node, agent=chart_agent, name="chart_generator")
+
+from langgraph.prebuilt import ToolNode
+
+tools = [tavily_tool, python_repl]
+tool_node = ToolNode(tools)
+
+# Either agent can decide to end
+from typing import Literal
+
+def router(state) -> Literal["call_tool", "__end__", "continue"]:
+    # This is the router
+    messages = state["messages"]
+    last_message = messages[-1]
+    if last_message.tool_calls:
+        # The previous agent is invoking a tool
+        return "call_tool"
+    if "FINAL ANSWER" in last_message.content:
+        # Any agent decided the work is done
+        return "__end__"
+    return "continue"
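
router only inspects the last message in the state. A quick check of the three outcomes (the tool name and call id here are made up):

```python
assert router({"messages": [AIMessage(content="", tool_calls=[
    {"name": "tavily_search_results_json", "args": {"query": "UK GDP"}, "id": "call_1"},
])]}) == "call_tool"
assert router({"messages": [AIMessage(content="FINAL ANSWER: chart rendered.")]}) == "__end__"
assert router({"messages": [AIMessage(content="Partial results so far.")]}) == "continue"
```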
 
 workflow = StateGraph(AgentState)
+
 workflow.add_node("Researcher", research_node)
+workflow.add_node("chart_generator", chart_node)
+workflow.add_node("call_tool", tool_node)
+
+workflow.add_conditional_edges(
+    "Researcher",
+    router,
+    {"continue": "chart_generator", "call_tool": "call_tool", "__end__": END},
+)
+workflow.add_conditional_edges(
+    "chart_generator",
+    router,
+    {"continue": "Researcher", "call_tool": "call_tool", "__end__": END},
+)
+workflow.add_conditional_edges(
+    "call_tool",
+    # Each agent node updates the 'sender' field
+    # the tool calling node does not, meaning
+    # this edge will route back to the original agent
+    # who invoked the tool
+    lambda x: x["sender"],
+    {
+        "Researcher": "Researcher",
+        "chart_generator": "chart_generator",
+    },
+)
+workflow.set_entry_point("Researcher")
 graph = workflow.compile()
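
A quick structural sanity check is possible through the same get_graph() used for drawing below; a sketch, since the exact node keys depend on the langgraph version:

```python
print(graph.get_graph().nodes.keys())
# Expected to include: __start__, Researcher, chart_generator, call_tool, __end__
```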
 
+from IPython.display import Image, display
+
+try:
+    display(Image(graph.get_graph(xray=True).draw_mermaid_png()))
+except:
+    # This requires some extra dependencies and is optional
+    pass
+
+events = graph.stream(
+    {
+        "messages": [
+            HumanMessage(
+                content="Fetch the UK's GDP over the past 5 years,"
+                " then draw a line graph of it."
+                " Once you code it up, finish."
+            )
+        ],
+    },
+    # Maximum number of steps to take in the graph
+    {"recursion_limit": 150},
+)
+for s in events:
+    print(s)
+    print("----")
+
 ###
 
+def invoke(openai_api_key):
     if (openai_api_key == ""):
         raise gr.Error("OpenAI API Key is required.")
 
     os.environ["OPENAI_API_KEY"] = openai_api_key
+
+    return "TODO"
 
 gr.close_all()
 
 demo = gr.Interface(fn = invoke,
                     inputs = [gr.Textbox(label = "OpenAI API Key", type = "password", lines = 1),
+                              gr.Textbox(label = "TODO", value="TODO", lines = 1)],
+                    outputs = [gr.Markdown(label = "TODO", value="TODO")],
+                    title = "Multi-Agent RAG: Chart Generation",
                     description = "TODO")
 
 demo.launch()
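
As committed, invoke still returns "TODO" and the graph.stream demo runs at import time rather than from the UI. A minimal sketch of wiring the compiled graph into the Gradio handler instead (the prompt text and final-message extraction are assumptions):

```python
def invoke(openai_api_key):
    if openai_api_key == "":
        raise gr.Error("OpenAI API Key is required.")
    os.environ["OPENAI_API_KEY"] = openai_api_key
    final_state = graph.invoke(
        {"messages": [HumanMessage(
            content="Fetch the UK's GDP over the past 5 years,"
            " then draw a line graph of it.")]},
        {"recursion_limit": 150},
    )
    # Return the content of the last message, e.g. the FINAL ANSWER text.
    return final_state["messages"][-1].content
```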