Spaces:
Running
Running
import json | |
from typing import Callable, Dict, List, Union | |
from pydantic import BaseModel, Field | |
from lagent.actions import ActionExecutor, AsyncActionExecutor, BaseAction | |
from lagent.agents.agent import Agent, AsyncAgent | |
from lagent.agents.aggregator import DefaultAggregator | |
from lagent.hooks import ActionPreprocessor | |
from lagent.llms import BaseLLM | |
from lagent.memory import Memory | |
from lagent.prompts.parsers.json_parser import JSONParser | |
from lagent.prompts.prompt_template import PromptTemplate | |
from lagent.schema import AgentMessage | |
from lagent.utils import create_object | |
# Tool-selection system prompt (Chinese): "You are an assistant that can call
# external tools; the available tools are: <tools> <output format> Begin!".
# Its placeholders match the keywords ReAct/AsyncReAct pass to
# ``template.format``: ``action_info`` receives the JSON-dumped tool
# descriptions and ``output_format`` the parser's format instruction.
select_action_template = (
    '你是一个可以调用外部工具的助手,可以使用的工具包括:\n'
    '{action_info}\n'
    '{output_format}\n'
    '开始!'
)
# Output-format instruction (Chinese): "If you use a tool, reply in the
# following format: <function_format>. If you already know the answer or do
# not need a tool, reply in the following format <finish_format>".
# The demo under ``__main__`` passes it to ``JSONParser`` together with the
# ``function_format`` / ``finish_format`` pydantic models.
output_format_template = (
    '如果使用工具请遵循以下格式回复:\n'
    '{function_format}\n'
    '如果你已经知道了答案,或者你不需要工具,请遵循以下格式回复\n'
    '{finish_format}'
)
class ReAct(Agent):
    """Synchronous ReAct loop: an LLM selects a tool, the tool runs, repeat.

    Each turn, ``select_agent`` (an :class:`Agent` wrapping ``llm``) replies
    to the current message. If ``finish_condition`` accepts the reply it is
    returned. Otherwise the reply is handed to ``actions`` (an
    :class:`ActionExecutor`), whose output becomes the next turn's input.

    Args:
        llm: LLM instance, or a config dict for one.
        actions: Tool(s) given to the :class:`ActionExecutor`.
        template: Prompt with ``{action_info}`` and ``{output_format}``
            placeholders. When ``None``, the module-level
            ``select_action_template`` is used.
        memory: Memory config for the select agent.
        output_format: Parser instance, or a config dict for one. Its
            ``format_instruction()`` is spliced into the prompt.
        aggregator: Aggregator config for the select agent.
        hooks: Hook configs given to both the action executor and the
            select agent.
        finish_condition: Predicate on the select agent's reply; ``True``
            ends the loop. The default looks for ``'conclusion'`` in the
            reply's ``content`` or ``formatted`` value.
        max_turn: Maximum number of select/execute rounds.
        **kwargs: Forwarded to :class:`Agent`.
    """

    def __init__(self,
                 llm: Union[BaseLLM, Dict],
                 actions: Union[BaseAction, List[BaseAction]],
                 template: Union[PromptTemplate, str] = None,
                 memory: Dict = dict(type=Memory),
                 output_format: Dict = dict(type=JSONParser),
                 aggregator: Dict = dict(type=DefaultAggregator),
                 hooks: List = [dict(type=ActionPreprocessor)],
                 # ``or ()`` keeps the membership test from raising
                 # TypeError when content/formatted is None.
                 finish_condition: Callable[[AgentMessage], bool] = (
                     lambda m: 'conclusion' in (m.content or ())
                     or 'conclusion' in (m.formatted or ())),
                 max_turn: int = 5,
                 **kwargs):
        self.max_turn = max_turn
        self.finish_condition = finish_condition
        self.actions: ActionExecutor = create_object(
            dict(type=ActionExecutor, actions=actions, hooks=hooks))
        # The default ``output_format`` is a config dict, which has no
        # ``format_instruction``. Build it before rendering the prompt.
        # Already-constructed parsers are used as-is.
        if isinstance(output_format, dict):
            output_format = create_object(output_format)
        if template is None:
            template = select_action_template
        # ensure_ascii=False keeps non-ASCII tool descriptions readable in the
        # (Chinese) prompt instead of turning them into \uXXXX escapes.
        action_info = json.dumps(
            self.actions.description(), ensure_ascii=False)
        select_agent = dict(
            type=Agent,
            llm=llm,
            template=template.format(
                action_info=action_info,
                output_format=output_format.format_instruction()),
            output_format=output_format,
            memory=memory,
            aggregator=aggregator,
            hooks=hooks,
        )
        self.select_agent = create_object(select_agent)
        super().__init__(**kwargs)

    def forward(self, message: AgentMessage, **kwargs) -> AgentMessage:
        """Alternate tool selection and tool execution for up to ``max_turn`` rounds.

        Returns the select agent's reply as soon as ``finish_condition``
        accepts it. If no reply is accepted, returns the last action
        executor output, or ``message`` unchanged when ``max_turn <= 0``.
        """
        for _ in range(self.max_turn):
            message = self.select_agent(message)
            if self.finish_condition(message):
                return message
            message = self.actions(message)
        return message
class AsyncReAct(AsyncAgent):
    """Asynchronous ReAct loop: an LLM selects a tool, the tool runs, repeat.

    Each turn, ``select_agent`` (an :class:`AsyncAgent` wrapping ``llm``) is
    awaited on the current message. If ``finish_condition`` accepts the
    reply it is returned. Otherwise the reply is handed to ``actions`` (an
    :class:`AsyncActionExecutor`), whose awaited output becomes the next
    turn's input.

    Args:
        llm: LLM instance, or a config dict for one.
        actions: Tool(s) given to the :class:`AsyncActionExecutor`.
        template: Prompt with ``{action_info}`` and ``{output_format}``
            placeholders. When ``None``, the module-level
            ``select_action_template`` is used.
        memory: Memory config for the select agent.
        output_format: Parser instance, or a config dict for one. Its
            ``format_instruction()`` is spliced into the prompt.
        aggregator: Aggregator config for the select agent.
        hooks: Hook configs given to both the action executor and the
            select agent.
        finish_condition: Synchronous predicate on the select agent's reply;
            ``True`` ends the loop. The default looks for ``'conclusion'``
            in the reply's ``content`` or ``formatted`` value.
        max_turn: Maximum number of select/execute rounds.
        **kwargs: Forwarded to :class:`AsyncAgent`.
    """

    def __init__(self,
                 llm: Union[BaseLLM, Dict],
                 actions: Union[BaseAction, List[BaseAction]],
                 template: Union[PromptTemplate, str] = None,
                 memory: Dict = dict(type=Memory),
                 output_format: Dict = dict(type=JSONParser),
                 aggregator: Dict = dict(type=DefaultAggregator),
                 hooks: List = [dict(type=ActionPreprocessor)],
                 # ``or ()`` keeps the membership test from raising
                 # TypeError when content/formatted is None.
                 finish_condition: Callable[[AgentMessage], bool] = (
                     lambda m: 'conclusion' in (m.content or ())
                     or 'conclusion' in (m.formatted or ())),
                 max_turn: int = 5,
                 **kwargs):
        self.max_turn = max_turn
        self.finish_condition = finish_condition
        self.actions: AsyncActionExecutor = create_object(
            dict(type=AsyncActionExecutor, actions=actions, hooks=hooks))
        # The default ``output_format`` is a config dict, which has no
        # ``format_instruction``. Build it before rendering the prompt.
        # Already-constructed parsers are used as-is.
        if isinstance(output_format, dict):
            output_format = create_object(output_format)
        if template is None:
            template = select_action_template
        # ensure_ascii=False keeps non-ASCII tool descriptions readable in the
        # (Chinese) prompt instead of turning them into \uXXXX escapes.
        action_info = json.dumps(
            self.actions.description(), ensure_ascii=False)
        select_agent = dict(
            type=AsyncAgent,
            llm=llm,
            template=template.format(
                action_info=action_info,
                output_format=output_format.format_instruction()),
            output_format=output_format,
            memory=memory,
            aggregator=aggregator,
            hooks=hooks,
        )
        self.select_agent = create_object(select_agent)
        super().__init__(**kwargs)

    async def forward(self, message: AgentMessage, **kwargs) -> AgentMessage:
        """Alternate tool selection and tool execution for up to ``max_turn`` rounds.

        Returns the select agent's reply as soon as ``finish_condition``
        accepts it. If no reply is accepted, returns the last action
        executor output, or ``message`` unchanged when ``max_turn <= 0``.
        """
        for _ in range(self.max_turn):
            message = await self.select_agent(message)
            if self.finish_condition(message):
                return message
            message = await self.actions(message)
        return message
if __name__ == '__main__':
    # Demo: a ReAct agent backed by GPT-4o that may call a Python interpreter,
    # answering two follow-up questions in the same conversation.
    from lagent.llms import GPTAPI

    # Shared field description ("Describe the current state and known
    # information; this clarifies what is known and where to search next.").
    thought_desc = '描述当前所处的状态和已知信息。这有助于明确目前所掌握的信息和接下来的搜索方向。'

    class ActionCall(BaseModel):
        # Name of the function to call.
        name: str = Field(description='调用的函数名称')
        # Arguments for the call.
        parameters: Dict = Field(description='调用函数的参数')

    class ActionFormat(BaseModel):
        # Reply shape when the model wants to use a tool.
        thought_process: str = Field(description=thought_desc)
        # The operation for this step: function name plus arguments.
        action: ActionCall = Field(description='当前步骤需要执行的操作,包括函数名称和参数。')

    class FinishFormat(BaseModel):
        # Reply shape when the model answers directly; the ``conclusion``
        # field is what ReAct's default finish_condition looks for.
        thought_process: str = Field(description=thought_desc)
        # Summary of the findings that answers the question.
        conclusion: str = Field(description='总结当前的搜索结果,回答问题。')

    agent = ReAct(
        llm=dict(
            type=GPTAPI,
            model_type='gpt-4o-2024-05-13',
            key=None,
            max_new_tokens=4096,
            proxies=dict(),
            retry=1000,
        ),
        template=PromptTemplate(select_action_template),
        output_format=JSONParser(
            output_format_template,
            function_format=ActionFormat,
            finish_format=FinishFormat,
        ),
        aggregator=dict(type='DefaultAggregator'),
        actions=[dict(type='PythonInterpreter')],
    )

    # Second query relies on conversation memory from the first.
    for query in ('用 Python 计算一下 3 ** 5', ' 2 ** 5 呢'):
        response = agent(AgentMessage(sender='user', content=query))
        print(response)