# rise-ai / agent/_create.py
def agent(payload):
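    """Build and run a LangGraph MessageGraph for one payload.

    The graph routes the conversation between the main agent, tool execution
    and a JSON rendering step, then returns the final message content.
    """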
    DEBUG = True

    from agent.memory import Memory
    memory = Memory(payload)

    from agent.jsonencoder import json_parse_chain

    from agent.agent_main import Chain_Main_Agent
    chain_main_agent = Chain_Main_Agent(memory)

    from agent.toolset import tool_executor, converted_tools

    from langgraph.prebuilt import ToolInvocation
    import json
    from langchain_core.messages import FunctionMessage
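
    # Graph nodes: each receives the running list of messages and returns the
    # next message (or agent response) to append to the conversation.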
    def call_main_agent(messages):
        response = chain_main_agent.invoke({"conversation": messages, "thread_id": memory.thread_id})
        if DEBUG: print("call_main_agent called")
        return response
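
    # Execute the tool named in the last message's function_call and wrap the
    # result in a FunctionMessage so the main agent can see it.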
    def use_tool(messages):
        last_message = messages[-1]
        action = ToolInvocation(
            tool=last_message.additional_kwargs["function_call"]["name"],
            tool_input=json.loads(last_message.additional_kwargs["function_call"]["arguments"]),
        )
        response = tool_executor.invoke(action)
        function_message = FunctionMessage(content=str(response), name=action.tool)
        if DEBUG: print("Suggesting Tool to use... " + action.tool)
        return function_message
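
    # Turn the finished conversation into the JSON structure produced by
    # json_parse_chain and return it as the graph's final AIMessage.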
    def render_output(messages):
        response = json_parse_chain.invoke({"conversation": messages, "thread_id": memory.thread_id})
        if DEBUG: print("Rendering output")
        from langchain_core.messages import AIMessage
        response = json.dumps(response)
        return AIMessage(content=response)
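
    # Wire the graph: main_agent decides, use_tool acts, render_output formats
    # the final answer.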
    from langgraph.graph import MessageGraph, END

    workflow = MessageGraph()

    workflow.add_node("main_agent", call_main_agent)
    workflow.add_node("use_tool", use_tool)
    workflow.add_node("render_output", render_output)

    workflow.set_entry_point("main_agent")
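
    # Route from main_agent: run a tool if the model requested one, otherwise
    # render the final output.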
    def should_continue(messages):
        last_message = messages[-1]
        if "function_call" not in last_message.additional_kwargs:
            return "render_output"
        else:
            return "continue"
    workflow.add_conditional_edges(
        "main_agent", should_continue,
        {
            "continue": "use_tool",
            "render_output": "render_output",
            # should_continue never returns "end", so this branch is currently unreachable.
            "end": END
        }
    )

    # After a tool runs, return control to the main agent; rendered output ends the run.
    workflow.add_edge('use_tool', 'main_agent')
    workflow.add_edge('render_output', END)
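
    # Compile with the Memory-backed checkpointer so conversation state is
    # persisted per thread_id.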
    app = workflow.compile(checkpointer=memory.checkpoints)
    from langchain_core.messages import HumanMessage

    user_input = payload.get("input") or "Can I earn credit?"
    inputs = [HumanMessage(content=user_input)]
    response = app.invoke(inputs, {"configurable": {"thread_id": memory.thread_id}})
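
    # Manual multi-turn memory check, left commented out: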
    '''
    inputs = [HumanMessage(content="My name is Mark")]
    response = app.invoke(inputs, {"configurable": {"thread_id": memory.thread_id}})
    print(response[-1].content)

    inputs = [HumanMessage(content="What is my name?")]
    response = app.invoke(inputs, {"configurable": {"thread_id": memory.thread_id}})
    print(response[-1].content)
    '''
    print(response[-1].content)
    return response[-1].content
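
# Usage sketch (hypothetical caller): only the "input" key is read directly in
# this file; any other keys the payload carries are consumed by Memory.
#
#     agent({"input": "Can I earn credit?"})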