File size: 6,099 Bytes
e679d69
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
import json
from typing import Callable, Dict, List, Union

from pydantic import BaseModel, Field

from lagent.actions import ActionExecutor, AsyncActionExecutor, BaseAction
from lagent.agents.agent import Agent, AsyncAgent
from lagent.agents.aggregator import DefaultAggregator
from lagent.hooks import ActionPreprocessor
from lagent.llms import BaseLLM
from lagent.memory import Memory
from lagent.prompts.parsers.json_parser import JSONParser
from lagent.prompts.prompt_template import PromptTemplate
from lagent.schema import AgentMessage
from lagent.utils import create_object

# System prompt for the action-selection step. It is filled through
# ``.format(action_info=..., output_format=...)``.
# English gloss of the Chinese text: "You are an assistant that can call
# external tools; the tools you can use include: {action_info}
# {output_format} Begin!"
select_action_template = """你是一个可以调用外部工具的助手,可以使用的工具包括:
{action_info}
{output_format}
开始!"""

# Reply-format instructions; in this file it is handed to ``JSONParser``
# together with ``function_format`` and ``finish_format`` models.
# English gloss: "If you use a tool, reply in the following format:
# {function_format}. If you already know the answer, or you do not need a
# tool, reply in the following format {finish_format}".
output_format_template = """如果使用工具请遵循以下格式回复:
{function_format}

如果你已经知道了答案,或者你不需要工具,请遵循以下格式回复
{finish_format}"""


class ReAct(Agent):
    """Agent that loops between selecting an action and executing it.

    Each turn, an inner ``Agent`` (``self.select_agent``) is called with the
    current message. If ``finish_condition`` accepts its reply, that reply is
    returned; otherwise the reply is passed to the action executor
    (``self.actions``) and the executor's result becomes the next input.

    Args:
        llm: LLM instance or config dict, forwarded to the inner agent.
        actions: Action(s) wrapped into an ``ActionExecutor``.
        template: Prompt template formatted with ``action_info`` (a JSON dump
            of ``self.actions.description()``) and ``output_format`` (the
            parser's ``format_instruction()``). When ``None``, the
            module-level ``select_action_template`` is used.
        memory: Memory config, forwarded to the inner agent.
        output_format: Parser instance, or a config dict that is instantiated
            with ``create_object`` before use.
        aggregator: Aggregator config, forwarded to the inner agent.
        hooks: Hooks given to both the action executor and the inner agent.
        finish_condition: Predicate on the inner agent's reply that ends the
            loop. The default looks for ``'conclusion'`` in ``m.content`` or
            ``m.formatted``.
        max_turn: Maximum number of select/execute iterations.
        **kwargs: Forwarded to ``Agent.__init__``.
    """

    # NOTE: the dict/list defaults below are evaluated once and shared by all
    # instances; this constructor only reads them and never mutates them.
    def __init__(self,
                 llm: Union[BaseLLM, Dict],
                 actions: Union[BaseAction, List[BaseAction]],
                 template: Union[PromptTemplate, str] = None,
                 memory: Dict = dict(type=Memory),
                 output_format: Dict = dict(type=JSONParser),
                 aggregator: Dict = dict(type=DefaultAggregator),
                 hooks: List = [dict(type=ActionPreprocessor)],
                 finish_condition: Callable[[AgentMessage], bool] = lambda m:
                 'conclusion' in m.content or 'conclusion' in m.formatted,
                 max_turn: int = 5,
                 **kwargs):
        self.max_turn = max_turn
        self.finish_condition = finish_condition
        actions = dict(
            type=ActionExecutor,
            actions=actions,
            hooks=hooks,
        )
        self.actions: ActionExecutor = create_object(actions)
        # The declared default (None) used to crash on ``template.format``;
        # fall back to the module's own selection prompt instead.
        if template is None:
            template = PromptTemplate(select_action_template)
        # A config dict (including the declared default) has no
        # ``format_instruction``; build the parser first. A copy is passed so
        # the shared default dict is never handed out directly.
        if isinstance(output_format, dict):
            output_format = create_object(dict(output_format))
        select_agent = dict(
            type=Agent,
            llm=llm,
            template=template.format(
                action_info=json.dumps(self.actions.description()),
                output_format=output_format.format_instruction()),
            output_format=output_format,
            memory=memory,
            aggregator=aggregator,
            hooks=hooks,
        )
        self.select_agent = create_object(select_agent)
        super().__init__(**kwargs)

    def forward(self, message: AgentMessage, **kwargs) -> AgentMessage:
        """Run the select/execute loop for at most ``max_turn`` turns.

        Args:
            message: Initial message given to the inner agent.
            **kwargs: Accepted but not used by this method.

        Returns:
            The first inner-agent reply accepted by ``finish_condition``;
            otherwise the action executor's result from the final turn (or
            ``message`` itself when ``max_turn`` is not positive).
        """
        for _ in range(self.max_turn):
            message = self.select_agent(message)
            if self.finish_condition(message):
                return message
            message = self.actions(message)
        return message


class AsyncReAct(AsyncAgent):
    """Asynchronous agent looping between action selection and execution.

    Each turn, an inner ``AsyncAgent`` (``self.select_agent``) is awaited
    with the current message. If ``finish_condition`` accepts its reply, that
    reply is returned; otherwise the reply is passed to the awaited action
    executor (``self.actions``) and its result becomes the next input.

    Args:
        llm: LLM instance or config dict, forwarded to the inner agent.
        actions: Action(s) wrapped into an ``AsyncActionExecutor``.
        template: Prompt template formatted with ``action_info`` (a JSON dump
            of ``self.actions.description()``) and ``output_format`` (the
            parser's ``format_instruction()``). When ``None``, the
            module-level ``select_action_template`` is used.
        memory: Memory config, forwarded to the inner agent.
        output_format: Parser instance, or a config dict that is instantiated
            with ``create_object`` before use.
        aggregator: Aggregator config, forwarded to the inner agent.
        hooks: Hooks given to both the action executor and the inner agent.
        finish_condition: Predicate on the inner agent's reply that ends the
            loop. The default looks for ``'conclusion'`` in ``m.content`` or
            ``m.formatted``.
        max_turn: Maximum number of select/execute iterations.
        **kwargs: Forwarded to ``AsyncAgent.__init__``.
    """

    # NOTE: the dict/list defaults below are evaluated once and shared by all
    # instances; this constructor only reads them and never mutates them.
    def __init__(self,
                 llm: Union[BaseLLM, Dict],
                 actions: Union[BaseAction, List[BaseAction]],
                 template: Union[PromptTemplate, str] = None,
                 memory: Dict = dict(type=Memory),
                 output_format: Dict = dict(type=JSONParser),
                 aggregator: Dict = dict(type=DefaultAggregator),
                 hooks: List = [dict(type=ActionPreprocessor)],
                 finish_condition: Callable[[AgentMessage], bool] = lambda m:
                 'conclusion' in m.content or 'conclusion' in m.formatted,
                 max_turn: int = 5,
                 **kwargs):
        self.max_turn = max_turn
        self.finish_condition = finish_condition
        actions = dict(
            type=AsyncActionExecutor,
            actions=actions,
            hooks=hooks,
        )
        self.actions: AsyncActionExecutor = create_object(actions)
        # The declared default (None) used to crash on ``template.format``;
        # fall back to the module's own selection prompt instead.
        if template is None:
            template = PromptTemplate(select_action_template)
        # A config dict (including the declared default) has no
        # ``format_instruction``; build the parser first. A copy is passed so
        # the shared default dict is never handed out directly.
        if isinstance(output_format, dict):
            output_format = create_object(dict(output_format))
        select_agent = dict(
            type=AsyncAgent,
            llm=llm,
            template=template.format(
                action_info=json.dumps(self.actions.description()),
                output_format=output_format.format_instruction()),
            output_format=output_format,
            memory=memory,
            aggregator=aggregator,
            hooks=hooks,
        )
        self.select_agent = create_object(select_agent)
        super().__init__(**kwargs)

    async def forward(self, message: AgentMessage, **kwargs) -> AgentMessage:
        """Run the select/execute loop for at most ``max_turn`` turns.

        Args:
            message: Initial message given to the inner agent.
            **kwargs: Accepted but not used by this method.

        Returns:
            The first inner-agent reply accepted by ``finish_condition``;
            otherwise the action executor's result from the final turn (or
            ``message`` itself when ``max_turn`` is not positive).
        """
        for _ in range(self.max_turn):
            message = await self.select_agent(message)
            if self.finish_condition(message):
                return message
            message = await self.actions(message)
        return message


if __name__ == '__main__':
    # Demo: build a ReAct agent backed by GPTAPI with a Python interpreter
    # action, send it two user messages in sequence, and print each reply.
    from lagent.llms import GPTAPI

    # Both reply schemas share the same "thought_process" description.
    thought_description = '描述当前所处的状态和已知信息。这有助于明确目前所掌握的信息和接下来的搜索方向。'

    class ActionCall(BaseModel):
        # Name and arguments of the function to call.
        name: str = Field(description='调用的函数名称')
        parameters: Dict = Field(description='调用函数的参数')

    class ActionFormat(BaseModel):
        # Reply schema used when the model wants to call a tool.
        thought_process: str = Field(description=thought_description)
        action: ActionCall = Field(description='当前步骤需要执行的操作,包括函数名称和参数。')

    class FinishFormat(BaseModel):
        # Reply schema used when the model gives its final answer.
        thought_process: str = Field(description=thought_description)
        conclusion: str = Field(description='总结当前的搜索结果,回答问题。')

    agent = ReAct(
        llm=dict(
            type=GPTAPI,
            model_type='gpt-4o-2024-05-13',
            key=None,
            max_new_tokens=4096,
            proxies=dict(),
            retry=1000),
        template=PromptTemplate(select_action_template),
        output_format=JSONParser(
            output_format_template,
            function_format=ActionFormat,
            finish_format=FinishFormat),
        aggregator=dict(type='DefaultAggregator'),
        actions=[dict(type='PythonInterpreter')],
    )

    # The same agent instance handles both queries, one after the other.
    for query in ('用 Python 计算一下 3 ** 5', ' 2 ** 5 呢'):
        print(agent(AgentMessage(sender='user', content=query)))