import os

from langchain_core.messages import (
    BaseMessage,
    ToolMessage,
    HumanMessage,
)
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langgraph.graph import END, StateGraph
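

# Helper for building a worker agent: a collaboration prompt wired to the LLM
# with the agent's tools bound.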
def create_agent(llm, tools, system_message: str):
    """Create an agent."""
    prompt = ChatPromptTemplate.from_messages(
        [
            (
                "system",
                "You are a helpful AI assistant, collaborating with other assistants."
                " Use the provided tools to progress towards answering the question."
                " If you are unable to fully answer, that's OK; another assistant with different tools"
                " will help where you left off. Execute what you can to make progress."
                " If you or any of the other assistants have the final answer or deliverable,"
                " prefix your response with FINAL ANSWER so the team knows to stop."
                " You have access to the following tools: {tool_names}.\n{system_message}",
            ),
            MessagesPlaceholder(variable_name="messages"),
        ]
    )
    prompt = prompt.partial(system_message=system_message)
    prompt = prompt.partial(tool_names=", ".join([tool.name for tool in tools]))
    return prompt | llm.bind_tools(tools)
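
# Example (hypothetical, not executed here): any chat model that supports
# tool binding can be passed as `llm`, e.g.
#   agent = create_agent(some_llm, [some_tool], system_message="Be concise.")
#   response = agent.invoke({"messages": [HumanMessage(content="Hello")]})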


from langchain_core.tools import tool
from typing import Annotated
from langchain_experimental.utilities import PythonREPL
from langchain_community.tools.tavily_search import TavilySearchResults

tavily_tool = TavilySearchResults(max_results=5)
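
# Warning: this executes code locally, which can be unsafe when not sandboxed.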
repl = PythonREPL()


@tool
def python_repl(
    code: Annotated[str, "The python code to execute to generate your chart."]
):
    """Use this to execute python code. If you want to see the output of a value,
    you should print it out with `print(...)`. This is visible to the user."""
    try:
        result = repl.run(code)
    except BaseException as e:
        return f"Failed to execute. Error: {repr(e)}"
    result_str = f"Successfully executed:\n```python\n{code}\n```\nStdout: {result}"
    return (
        result_str + "\n\nIf you have completed all tasks, respond with FINAL ANSWER."
    )


import operator
from typing import Annotated, Sequence, TypedDict

from langchain_openai import ChatOpenAI
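

# Shared graph state: the running message list plus the name of the last
# speaker. `operator.add` makes `messages` accumulate across node updates.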
class AgentState(TypedDict):
    messages: Annotated[Sequence[BaseMessage], operator.add]
    sender: str


import functools

from langchain_core.messages import AIMessage
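

# Helper to create a graph node for a given agent: run the agent, then wrap its
# output as a named AIMessage so it can be appended to the shared state.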
def agent_node(state, agent, name):
    result = agent.invoke(state)
    # Convert the agent output into a format that is suitable to append to the global state
    if isinstance(result, ToolMessage):
        pass
    else:
        result = AIMessage(**result.dict(exclude={"type", "name"}), name=name)
    return {
        "messages": [result],
        # Track the sender so we know who to route back to after a tool call
        "sender": name,
    }


llm = ChatOpenAI(model="gpt-4-1106-preview")
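
# Research agent and node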
research_agent = create_agent(
    llm,
    [tavily_tool],
    system_message="You should provide accurate data for the chart_generator to use.",
)
research_node = functools.partial(agent_node, agent=research_agent, name="Researcher")
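
# Chart generator agent and node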
chart_agent = create_agent(
    llm,
    [python_repl],
    system_message="Any charts you display will be visible to the user.",
)
chart_node = functools.partial(agent_node, agent=chart_agent, name="chart_generator")
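
# A single ToolNode executes tool calls on behalf of either agent.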
from langgraph.prebuilt import ToolNode

tools = [tavily_tool, python_repl]
tool_node = ToolNode(tools)
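

# Either agent can decide the work is finished and end the run.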
from typing import Literal


def router(state) -> Literal["call_tool", "__end__", "continue"]:
    # Inspect the last message to choose the next step: invoke a tool, stop,
    # or hand off to the other agent.
    messages = state["messages"]
    last_message = messages[-1]
    if last_message.tool_calls:
        # The previous agent is invoking a tool
        return "call_tool"
    if "FINAL ANSWER" in last_message.content:
        # Any agent decided the work is done
        return "__end__"
    return "continue"


workflow = StateGraph(AgentState)

workflow.add_node("Researcher", research_node)
workflow.add_node("chart_generator", chart_node)
workflow.add_node("call_tool", tool_node)
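
# On "continue" an agent hands off to its teammate; "call_tool" goes to the
# shared tool node; "__end__" terminates the graph.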
workflow.add_conditional_edges(
    "Researcher",
    router,
    {"continue": "chart_generator", "call_tool": "call_tool", "__end__": END},
)
workflow.add_conditional_edges(
    "chart_generator",
    router,
    {"continue": "Researcher", "call_tool": "call_tool", "__end__": END},
)

workflow.add_conditional_edges(
    "call_tool",
    # Each agent node updates the "sender" field; the tool-calling node does not,
    # so this edge routes back to the agent that originally invoked the tool.
    lambda x: x["sender"],
    {
        "Researcher": "Researcher",
        "chart_generator": "chart_generator",
    },
)
workflow.set_entry_point("Researcher")
graph = workflow.compile()


from IPython.display import Image, display

try:
    display(Image(graph.get_graph(xray=True).draw_mermaid_png()))
except Exception:
    # Rendering requires optional extra dependencies; skip if they are missing.
    pass
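
# Kick off the collaboration and stream graph events as they happen.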
events = graph.stream(
    {
        "messages": [
            HumanMessage(
                content="Fetch the UK's GDP over the past 5 years,"
                " then draw a line graph of it."
                " Once you code it up, finish."
            )
        ],
    },
    # Maximum number of steps to take in the graph
    {"recursion_limit": 150},
)
for s in events:
    print(s)
    print("----")


def run_multi_agent():
    return "DONE"