# multi-agent-ai-langgraph / rag_langgraph.py
# Hugging Face file-viewer header: author bstraehle, commit 3b32f96 (verified),
# "Update rag_langgraph.py", 4.54 kB.
from typing import Annotated, List, Tuple, Union
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_core.tools import tool
from langchain.agents import AgentExecutor, create_openai_tools_agent
from langchain_core.messages import BaseMessage, HumanMessage
from langchain_openai import ChatOpenAI
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.output_parsers.openai_functions import JsonOutputFunctionsParser
import operator
from typing import Annotated, Any, Dict, List, Optional, Sequence, TypedDict
import functools
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langgraph.graph import StateGraph, END
class AgentState(TypedDict):
    """Shared state dictionary that flows between the nodes of the graph."""

    # Conversation history. The ``operator.add`` annotation is the reducer
    # LangGraph applies when a node returns new messages, i.e. node output is
    # concatenated onto the existing sequence instead of replacing it.
    messages: Annotated[Sequence[BaseMessage], operator.add]
    # Routing decision; the conditional edge out of the "Manager" node reads
    # this value as ``state["next"]``.
    next: str
def create_agent(llm: ChatOpenAI, tools: list, system_prompt: str):
    """Build an ``AgentExecutor`` around an OpenAI tools agent.

    Args:
        llm: Chat model that drives the agent.
        tools: Tools given both to the agent and to the executor that runs it.
        system_prompt: Text of the leading system message.

    Returns:
        An ``AgentExecutor`` wrapping the newly created agent.
    """
    # The prompt is: system instructions, then the running conversation
    # ("messages"), then the agent's own intermediate work
    # ("agent_scratchpad").
    template_parts = [
        ("system", system_prompt),
        MessagesPlaceholder(variable_name="messages"),
        MessagesPlaceholder(variable_name="agent_scratchpad"),
    ]
    agent_prompt = ChatPromptTemplate.from_messages(template_parts)
    tools_agent = create_openai_tools_agent(llm, tools, agent_prompt)
    return AgentExecutor(agent=tools_agent, tools=tools)
def agent_node(state, agent, name):
    """Run ``agent`` on ``state`` and wrap its output as a named message.

    Args:
        state: Input handed unchanged to ``agent.invoke``.
        agent: Object exposing ``invoke``; its result must contain an
            ``"output"`` entry.
        name: Name attached to the produced message.

    Returns:
        A dict whose ``"messages"`` entry is a one-element list holding a
        ``HumanMessage`` with the agent's output as content.
    """
    output_text = agent.invoke(state)["output"]
    reply = HumanMessage(content=output_text, name=name)
    return {"messages": [reply]}
def create_graph(topic, word_count):
    """Assemble and compile the supervisor/worker graph for a blog post.

    The graph has a "Manager" node (an LLM routing chain that picks the next
    worker or FINISH) and a single "Blogger" worker node equipped with the
    Tavily search tool. Every worker hands control back to the Manager; the
    Manager is the entry point.

    Args:
        topic: Subject of the blog post, interpolated into the Blogger prompt.
        word_count: Target length, interpolated into the Blogger prompt.

    Returns:
        The compiled graph produced by ``StateGraph.compile()``.
    """
    search_tool = TavilySearchResults(max_results=10)

    members = ["Blogger"]
    options = ["FINISH"] + members

    system_prompt = (
        "You are a supervisor tasked with managing a conversation between the"
        " following workers: {members}. Given the following user request,"
        " respond with the worker to act next. Each worker will perform a"
        " task and respond with their results and status. When finished,"
        " respond with FINISH."
    )

    # Function schema the supervisor LLM is forced to call; its single "next"
    # argument is restricted to the values in ``options``.
    next_schema = {
        "title": "Next",
        "anyOf": [
            {"enum": options},
        ],
    }
    route_function = {
        "name": "route",
        "description": "Select the next role.",
        "parameters": {
            "title": "routeSchema",
            "type": "object",
            "properties": {"next": next_schema},
            "required": ["next"],
        },
    }

    supervisor_messages = [
        ("system", system_prompt),
        MessagesPlaceholder(variable_name="messages"),
        (
            "system",
            "Given the conversation above, who should act next?"
            " Or should we FINISH? Select one of: {options}",
        ),
    ]
    supervisor_prompt = ChatPromptTemplate.from_messages(
        supervisor_messages
    ).partial(options=str(options), members=", ".join(members))

    llm = ChatOpenAI(model="gpt-4o")
    routing_llm = llm.bind_functions(
        functions=[route_function], function_call="route"
    )
    supervisor_chain = supervisor_prompt | routing_llm | JsonOutputFunctionsParser()

    # NOTE: a separate "Researcher" worker is currently disabled.
    #research_agent = create_agent(llm, [tavily_tool], f"Research content on topic {topic}")
    #research_node = functools.partial(agent_node, agent=research_agent, name="Researcher")

    blogger_instructions = (
        f"Based on research papers, write a {word_count}-word blog post"
        f" on topic {topic}. Add a references section."
    )
    blogger_agent = create_agent(llm, [search_tool], blogger_instructions)
    blogger_node = functools.partial(
        agent_node, agent=blogger_agent, name="Blogger"
    )

    workflow = StateGraph(AgentState)
    #workflow.add_node("Researcher", research_node)
    workflow.add_node("Blogger", blogger_node)
    workflow.add_node("Manager", supervisor_chain)

    # Each worker always returns control to the Manager.
    for worker in members:
        workflow.add_edge(worker, "Manager")

    # The Manager's "next" value selects a worker node, or ends the run on
    # "FINISH".
    routes = {**{worker: worker for worker in members}, "FINISH": END}
    workflow.add_conditional_edges("Manager", lambda state: state["next"], routes)

    workflow.set_entry_point("Manager")
    return workflow.compile()
def run_multi_agent(topic, word_count):
    """Run the multi-agent graph for ``topic`` and return the generated text.

    The user request sent into the graph is ``topic`` itself. Previously a
    fixed string was sent regardless of the argument, so the supervisor saw a
    request that could contradict the Blogger prompt built from ``topic``.

    Args:
        topic: Subject of the blog post; used both for the worker prompt (via
            ``create_graph``) and as the initial user message.
        word_count: Target length forwarded to ``create_graph``.

    Returns:
        The ``content`` of the message at index 1 of the final state's
        ``"messages"`` list. Index 0 is the initial user message, so index 1
        is the first message appended by the graph.

    Raises:
        IndexError: If the graph finished without appending any message after
            the initial user message.
    """
    graph = create_graph(topic, word_count)
    result = graph.invoke({"messages": [HumanMessage(content=topic)]})
    messages = result["messages"]

    # Diagnostic output, kept as in the original implementation.
    print("###")
    print(result)
    print("###")
    print(messages)
    print("###")

    if len(messages) < 2:
        # Same exception type the bare index would raise, with a clearer cause.
        raise IndexError(
            "graph finished without producing a reply after the user message"
        )
    reply = messages[1]

    print(reply)
    print("###")
    print(reply.content)
    print("###")
    return reply.content