Lagent / lagent /agents /react.py
Superkingjcj's picture
Upload 111 files
e679d69 verified
raw
history blame
6.1 kB
import json
from typing import Callable, Dict, List, Union
from pydantic import BaseModel, Field
from lagent.actions import ActionExecutor, AsyncActionExecutor, BaseAction
from lagent.agents.agent import Agent, AsyncAgent
from lagent.agents.aggregator import DefaultAggregator
from lagent.hooks import ActionPreprocessor
from lagent.llms import BaseLLM
from lagent.memory import Memory
from lagent.prompts.parsers.json_parser import JSONParser
from lagent.prompts.prompt_template import PromptTemplate
from lagent.schema import AgentMessage
from lagent.utils import create_object
# System prompt for the tool-selecting agent. In English: "You are an
# assistant that can call external tools; the available tools include:
# {action_info} {output_format} Begin!". Both placeholders are filled in by
# the ReAct constructors via ``template.format(...)``.
select_action_template = """你是一个可以调用外部工具的助手,可以使用的工具包括:
{action_info}
{output_format}
开始!"""
# Reply-format instructions handed to the output parser. In English: "If you
# use a tool, reply in the following format: {function_format}. If you already
# know the answer, or you do not need a tool, reply in the following format:
# {finish_format}".
output_format_template = """如果使用工具请遵循以下格式回复:
{function_format}
如果你已经知道了答案,或者你不需要工具,请遵循以下格式回复
{finish_format}"""
class ReAct(Agent):
    """Agent that alternates between choosing a tool and executing it.

    Each turn calls an inner ``select_agent`` with the current message. If
    ``finish_condition`` accepts the reply, that reply is returned; otherwise
    the reply is passed to the action executor and the executor's result
    becomes the input of the next turn.

    Args:
        llm: Model instance or config forwarded to the inner selection agent.
        actions: Action(s) wrapped in an ``ActionExecutor`` config.
        template: Prompt whose ``format`` is called with the ``action_info``
            and ``output_format`` keys. ``None`` selects the module-level
            ``select_action_template``.
        memory: Memory config forwarded to the inner selection agent.
        output_format: Parser instance, or a config dict that is instantiated
            with ``create_object``. It must provide ``format_instruction()``.
        aggregator: Aggregator config forwarded to the inner selection agent.
        hooks: Hook configs given to both the executor and the inner agent.
        finish_condition: Predicate deciding whether a reply ends the loop.
        max_turn: Maximum number of select/execute turns per ``forward`` call.
        **kwargs: Forwarded to ``Agent.__init__``.
    """

    def __init__(self,
                 llm: Union[BaseLLM, Dict],
                 actions: Union[BaseAction, List[BaseAction]],
                 template: Union[PromptTemplate, str, None] = None,
                 memory: Dict = dict(type=Memory),
                 output_format: Union[JSONParser, Dict] = dict(type=JSONParser),
                 aggregator: Dict = dict(type=DefaultAggregator),
                 hooks: List = [dict(type=ActionPreprocessor)],
                 finish_condition: Callable[[AgentMessage], bool] = lambda m:
                 'conclusion' in m.content or 'conclusion' in m.formatted,
                 max_turn: int = 5,
                 **kwargs):
        self.max_turn = max_turn
        self.finish_condition = finish_condition

        # The documented default used to crash on ``None.format``; fall back
        # to the prompt shipped with this module instead.
        if template is None:
            template = select_action_template

        # A config dict has no ``format_instruction``; build the parser first
        # so the default ``dict(type=JSONParser)`` is usable.
        if isinstance(output_format, dict):
            output_format = create_object(output_format)

        executor_cfg = dict(
            type=ActionExecutor,
            actions=actions,
            hooks=hooks,
        )
        self.actions: ActionExecutor = create_object(executor_cfg)

        select_agent = dict(
            type=Agent,
            llm=llm,
            template=template.format(
                action_info=json.dumps(self.actions.description()),
                output_format=output_format.format_instruction()),
            output_format=output_format,
            memory=memory,
            aggregator=aggregator,
            hooks=hooks,
        )
        self.select_agent = create_object(select_agent)
        super().__init__(**kwargs)

    def forward(self, message: AgentMessage, **kwargs) -> AgentMessage:
        """Run up to ``max_turn`` select/execute turns and return the result.

        Returns the first reply accepted by ``finish_condition``. If no reply
        is accepted, returns the executor's last result (or the input message
        unchanged when ``max_turn`` is not positive).
        """
        for _ in range(self.max_turn):
            message = self.select_agent(message)
            if self.finish_condition(message):
                return message
            message = self.actions(message)
        # Turn budget exhausted without an accepted reply.
        return message
class AsyncReAct(AsyncAgent):
    """Asynchronous agent that alternates between choosing and running a tool.

    Each turn awaits an inner ``select_agent`` with the current message. If
    ``finish_condition`` accepts the reply, that reply is returned; otherwise
    the reply is passed to the async action executor and the awaited result
    becomes the input of the next turn.

    Args:
        llm: Model instance or config forwarded to the inner selection agent.
        actions: Action(s) wrapped in an ``AsyncActionExecutor`` config.
        template: Prompt whose ``format`` is called with the ``action_info``
            and ``output_format`` keys. ``None`` selects the module-level
            ``select_action_template``.
        memory: Memory config forwarded to the inner selection agent.
        output_format: Parser instance, or a config dict that is instantiated
            with ``create_object``. It must provide ``format_instruction()``.
        aggregator: Aggregator config forwarded to the inner selection agent.
        hooks: Hook configs given to both the executor and the inner agent.
        finish_condition: Predicate deciding whether a reply ends the loop.
        max_turn: Maximum number of select/execute turns per ``forward`` call.
        **kwargs: Forwarded to ``AsyncAgent.__init__``.
    """

    def __init__(self,
                 llm: Union[BaseLLM, Dict],
                 actions: Union[BaseAction, List[BaseAction]],
                 template: Union[PromptTemplate, str, None] = None,
                 memory: Dict = dict(type=Memory),
                 output_format: Union[JSONParser, Dict] = dict(type=JSONParser),
                 aggregator: Dict = dict(type=DefaultAggregator),
                 hooks: List = [dict(type=ActionPreprocessor)],
                 finish_condition: Callable[[AgentMessage], bool] = lambda m:
                 'conclusion' in m.content or 'conclusion' in m.formatted,
                 max_turn: int = 5,
                 **kwargs):
        self.max_turn = max_turn
        self.finish_condition = finish_condition

        # The documented default used to crash on ``None.format``; fall back
        # to the prompt shipped with this module instead.
        if template is None:
            template = select_action_template

        # A config dict has no ``format_instruction``; build the parser first
        # so the default ``dict(type=JSONParser)`` is usable.
        if isinstance(output_format, dict):
            output_format = create_object(output_format)

        executor_cfg = dict(
            type=AsyncActionExecutor,
            actions=actions,
            hooks=hooks,
        )
        self.actions: AsyncActionExecutor = create_object(executor_cfg)

        select_agent = dict(
            type=AsyncAgent,
            llm=llm,
            template=template.format(
                action_info=json.dumps(self.actions.description()),
                output_format=output_format.format_instruction()),
            output_format=output_format,
            memory=memory,
            aggregator=aggregator,
            hooks=hooks,
        )
        self.select_agent = create_object(select_agent)
        super().__init__(**kwargs)

    async def forward(self, message: AgentMessage, **kwargs) -> AgentMessage:
        """Run up to ``max_turn`` select/execute turns and return the result.

        Returns the first reply accepted by ``finish_condition``. If no reply
        is accepted, returns the executor's last result (or the input message
        unchanged when ``max_turn`` is not positive).
        """
        for _ in range(self.max_turn):
            message = await self.select_agent(message)
            if self.finish_condition(message):
                return message
            message = await self.actions(message)
        # Turn budget exhausted without an accepted reply.
        return message
if __name__ == '__main__':
    # Demo: build a ReAct agent backed by GPTAPI and ask it two questions
    # that share one conversation.
    from lagent.llms import GPTAPI

    # Description shared by the ``thought_process`` field of both schemas.
    thought_desc = '描述当前所处的状态和已知信息。这有助于明确目前所掌握的信息和接下来的搜索方向。'

    class ActionCall(BaseModel):
        # Name and arguments of the function the model wants to call.
        name: str = Field(description='调用的函数名称')
        parameters: Dict = Field(description='调用函数的参数')

    class ActionFormat(BaseModel):
        # Reply schema used when the model decides to call a tool.
        thought_process: str = Field(description=thought_desc)
        action: ActionCall = Field(description='当前步骤需要执行的操作,包括函数名称和参数。')

    class FinishFormat(BaseModel):
        # Reply schema used when the model answers without a tool.
        thought_process: str = Field(description=thought_desc)
        conclusion: str = Field(description='总结当前的搜索结果,回答问题。')

    system_prompt = PromptTemplate(select_action_template)
    parser = JSONParser(
        output_format_template,
        function_format=ActionFormat,
        finish_format=FinishFormat,
    )
    llm_cfg = {
        'type': GPTAPI,
        'model_type': 'gpt-4o-2024-05-13',
        'key': None,
        'max_new_tokens': 4096,
        'proxies': {},
        'retry': 1000,
    }
    react_agent = ReAct(
        llm=llm_cfg,
        template=system_prompt,
        output_format=parser,
        aggregator={'type': 'DefaultAggregator'},
        actions=[{'type': 'PythonInterpreter'}],
    )

    # Ask both questions in order on the same agent and print each reply.
    for question in ('用 Python 计算一下 3 ** 5', ' 2 ** 5 呢'):
        reply = react_agent(AgentMessage(sender='user', content=question))
        print(reply)