bstraehle commited on
Commit
6efa705
1 Parent(s): 06d9c7d

Update rag_langgraph.py

Browse files
Files changed (1) hide show
  1. rag_langgraph.py +11 -9
rag_langgraph.py CHANGED
@@ -14,8 +14,6 @@ from langchain_openai import ChatOpenAI
14
 
15
  from langgraph.graph import StateGraph, END
16
 
17
- LLM = "gpt-4o"
18
-
19
  class AgentState(TypedDict):
20
  messages: Annotated[Sequence[BaseMessage], operator.add]
21
  next: str
@@ -43,7 +41,7 @@ def today_tool(text: str) -> str:
43
  Any date mathematics should occur outside this function."""
44
  return (str(date.today()) + "\n\nIf you have completed all tasks, respond with FINAL ANSWER.")
45
 
46
- def create_graph(topic):
47
  tavily_tool = TavilySearchResults(max_results=10)
48
 
49
  members = ["Researcher"]
@@ -88,7 +86,7 @@ def create_graph(topic):
88
  ]
89
  ).partial(options=str(options), members=", ".join(members))
90
 
91
- llm = ChatOpenAI(model=LLM)
92
 
93
  supervisor_chain = (
94
  prompt
@@ -117,15 +115,19 @@ def create_graph(topic):
117
 
118
  return workflow.compile()
119
 
120
- def run_multi_agent(topic):
121
- graph = create_graph(topic)
 
122
  result = graph.invoke({
123
  "messages": [
124
  HumanMessage(content=topic)
125
  ]
126
  })
 
127
  article = result['messages'][-1].content
128
- #print("***")
129
- #print(article)
130
- #print("***")
 
 
131
  return article
 
14
 
15
  from langgraph.graph import StateGraph, END
16
 
 
 
17
  class AgentState(TypedDict):
18
  messages: Annotated[Sequence[BaseMessage], operator.add]
19
  next: str
 
41
  Any date mathematics should occur outside this function."""
42
  return (str(date.today()) + "\n\nIf you have completed all tasks, respond with FINAL ANSWER.")
43
 
44
+ def create_graph(model, topic):
45
  tavily_tool = TavilySearchResults(max_results=10)
46
 
47
  members = ["Researcher"]
 
86
  ]
87
  ).partial(options=str(options), members=", ".join(members))
88
 
89
+ llm = ChatOpenAI(model=model)
90
 
91
  supervisor_chain = (
92
  prompt
 
115
 
116
  return workflow.compile()
117
 
118
+ def run_multi_agent(model, topic):
119
+ graph = create_graph(model, topic)
120
+
121
  result = graph.invoke({
122
  "messages": [
123
  HumanMessage(content=topic)
124
  ]
125
  })
126
+
127
  article = result['messages'][-1].content
128
+
129
+ print("***")
130
+ print(article)
131
+ print("***")
132
+
133
  return article