import gradio as gr
import getpass
import os

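# Prompt once for any API key that is not already set in the environment.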
def _set_if_undefined(var: str):
    if not os.environ.get(var):
        os.environ[var] = getpass.getpass(f"Please provide your {var}")

_set_if_undefined("OPENAI_API_KEY")
_set_if_undefined("LANGCHAIN_API_KEY")
_set_if_undefined("TAVILY_API_KEY")

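# Enable LangSmith tracing; runs are logged under the project name below.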
os.environ["LANGCHAIN_TRACING_V2"] = "true"
os.environ["LANGCHAIN_PROJECT"] = "Multi-agent Collaboration"

from typing import Annotated, List, Tuple, Union

from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_core.tools import tool
from langchain_experimental.tools import PythonREPLTool

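# Tools: Tavily web search for the Researcher, and a Python REPL for the
# Coder (note: the REPL executes generated code locally).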
tavily_tool = TavilySearchResults(max_results=5)
python_repl_tool = PythonREPLTool()

from langchain.agents import AgentExecutor, create_openai_tools_agent
from langchain_core.messages import BaseMessage, HumanMessage
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_openai import ChatOpenAI

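# Each worker is an OpenAI tools agent wrapped in an AgentExecutor.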
def create_agent(llm: ChatOpenAI, tools: list, system_prompt: str):
    prompt = ChatPromptTemplate.from_messages(
        [
            ("system", system_prompt),
            MessagesPlaceholder(variable_name="messages"),
            MessagesPlaceholder(variable_name="agent_scratchpad"),
        ]
    )
    agent = create_openai_tools_agent(llm, tools, prompt)
    executor = AgentExecutor(agent=agent, tools=tools)
    return executor

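# Wrap an agent as a graph node: run it, then report its output back into
# the shared state as a named message.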
def agent_node(state, agent, name):
    result = agent.invoke(state)
    return {"messages": [HumanMessage(content=result["output"], name=name)]}

from langchain_core.output_parsers.openai_functions import JsonOutputFunctionsParser

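# Supervisor: decides which worker acts next, or FINISH when the task is done.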
members = ["Researcher", "Coder"]
system_prompt = (
    "You are a supervisor tasked with managing a conversation between the"
    " following workers: {members}. Given the following user request,"
    " respond with the worker to act next. Each worker will perform a"
    " task and respond with their results and status. When finished,"
    " respond with FINISH."
)

options = ["FINISH"] + members

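# Function-calling schema; the enum restricts the supervisor's answer to
# exactly one of the options above.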
function_def = {
    "name": "route",
    "description": "Select the next role.",
    "parameters": {
        "title": "routeSchema",
        "type": "object",
        "properties": {
            "next": {
                "title": "Next",
                "anyOf": [
                    {"enum": options},
                ],
            }
        },
        "required": ["next"],
    },
}

prompt = ChatPromptTemplate.from_messages(
    [
        ("system", system_prompt),
        MessagesPlaceholder(variable_name="messages"),
        (
            "system",
            "Given the conversation above, who should act next?"
            " Or should we FINISH? Select one of: {options}",
        ),
    ]
).partial(options=str(options), members=", ".join(members))

llm = ChatOpenAI(model="gpt-4-1106-preview")

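# Force the model to call "route" and parse its JSON arguments, yielding a
# dict like {"next": "Researcher"}.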
supervisor_chain = (
    prompt
    | llm.bind_functions(functions=[function_def], function_call="route")
    | JsonOutputFunctionsParser()
)

import operator
from typing import Annotated, Any, Dict, List, Optional, Sequence, TypedDict
import functools

from langgraph.graph import StateGraph, END

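# Shared graph state: the operator.add reducer appends each node's returned
# messages to the history instead of overwriting it; `next` holds the
# supervisor's routing decision.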
class AgentState(TypedDict):
    messages: Annotated[Sequence[BaseMessage], operator.add]
    next: str

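# Instantiate the two workers and bind each to a named graph node.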
research_agent = create_agent(llm, [tavily_tool], "You are a web researcher.")
research_node = functools.partial(agent_node, agent=research_agent, name="Researcher")

code_agent = create_agent(
    llm,
    [python_repl_tool],
    "You may generate safe python code to analyze data and generate charts using matplotlib.",
)
code_node = functools.partial(agent_node, agent=code_agent, name="Coder")

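# Wire the graph: the two workers plus a supervisor node built from the
# routing chain.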
workflow = StateGraph(AgentState)
workflow.add_node("Researcher", research_node)
workflow.add_node("Coder", code_node)
workflow.add_node("supervisor", supervisor_chain)

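# Every worker reports back to the supervisor when it finishes.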
for member in members:
    workflow.add_edge(member, "supervisor")

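# Route on the supervisor's "next" value; "FINISH" maps to the END sentinel.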
conditional_map = {k: k for k in members}
conditional_map["FINISH"] = END
workflow.add_conditional_edges("supervisor", lambda x: x["next"], conditional_map)

workflow.set_entry_point("supervisor")

graph = workflow.compile()

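# Gradio handler: validate the inputs, run the graph on the requested topic,
# and return the generated post.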
def invoke(openai_api_key, topic, word_count=500):
    if openai_api_key == "":
        raise gr.Error("OpenAI API Key is required.")
    if topic == "":
        raise gr.Error("Topic is required.")

    os.environ["OPENAI_API_KEY"] = openai_api_key

    # Stream the graph; each step yields a {node_name: state_update} dict.
    # The last worker message seen becomes the returned blog post.
    result = ""
    for s in graph.stream(
        {
            "messages": [
                HumanMessage(
                    content=f"Write a blog post about {topic} of approximately {int(word_count)} words."
                )
            ]
        }
    ):
        if "__end__" not in s:
            print(s)
            print("----")
            for update in s.values():
                if isinstance(update, dict) and update.get("messages"):
                    result = update["messages"][-1].content
    return result

gr.close_all()

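# UI: the API key, topic, and word count feed invoke(); the result is
# rendered as Markdown.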
demo = gr.Interface(
    fn=invoke,
    inputs=[
        gr.Textbox(label="OpenAI API Key", type="password", lines=1),
        gr.Textbox(label="Topic", value="TODO", lines=1),
        gr.Number(label="Word Count", value=1000, minimum=500, maximum=5000),
    ],
    outputs=[gr.Markdown(label="Generated Blog Post", value="TODO")],
    title="Multi-Agent RAG: Blog Post Generation",
    description="TODO",
)

demo.launch()