bstraehle commited on
Commit
da7c75f
1 Parent(s): d512015

Update rag_langgraph.py

Browse files
Files changed (1) hide show
  1. rag_langgraph.py +8 -4
rag_langgraph.py CHANGED
@@ -42,7 +42,7 @@ def agent_node(state, agent, name):
42
  def create_graph(topic, word_count):
43
  tavily_tool = TavilySearchResults(max_results=10)
44
 
45
- members = ["Researcher"]
46
 
47
  system_prompt = (
48
  "You are a manager tasked with managing a conversation between the"
@@ -92,11 +92,15 @@ def create_graph(topic, word_count):
92
  | JsonOutputFunctionsParser()
93
  )
94
 
95
- research_agent = create_agent(llm, [tavily_tool], system_prompt=f"Research content on topic {topic}, prioritizing research papers. Then write a {word_count}-word article on topic {topic}. Add a references section with research papers.")
96
- research_node = functools.partial(agent_node, agent=research_agent, name="Researcher")
97
-
 
 
 
98
  workflow = StateGraph(AgentState)
99
  workflow.add_node("Researcher", research_node)
 
100
  workflow.add_node("Manager", supervisor_chain)
101
 
102
  for member in members:
 
42
  def create_graph(topic, word_count):
43
  tavily_tool = TavilySearchResults(max_results=10)
44
 
45
+ members = ["Researcher", "Writer"]
46
 
47
  system_prompt = (
48
  "You are a manager tasked with managing a conversation between the"
 
92
  | JsonOutputFunctionsParser()
93
  )
94
 
95
+ researcher_agent = create_agent(llm, [tavily_tool], system_prompt=f"Prioritizing research papers, research content on topic: {topic}.")
96
+ researcher_node = functools.partial(agent_node, agent=researcher_agent, name="Researcher")
97
+
98
+ writer_agent = create_agent(llm, [None], system_prompt=f"Write a {word_count}-word article on topic: {topic}. Add a references section with research papers.")
99
+ writer_node = functools.partial(agent_node, agent=writer_agent, name="Writer")
100
+
101
  workflow = StateGraph(AgentState)
102
  workflow.add_node("Researcher", research_node)
103
+ workflow.add_node("Writer", research_node)
104
  workflow.add_node("Manager", supervisor_chain)
105
 
106
  for member in members: