File size: 12,486 Bytes
a8b3f00
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
import json
from collections.abc import Mapping
from copy import deepcopy
from datetime import datetime, timezone
from mimetypes import guess_type
from typing import Any, Optional, Union

from yarl import URL

from core.app.entities.app_invoke_entities import InvokeFrom
from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler
from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
from core.file import FileType
from core.file.models import FileTransferMethod
from core.ops.ops_trace_manager import TraceQueueManager
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolInvokeMessageBinary, ToolInvokeMeta, ToolParameter
from core.tools.errors import (
    ToolEngineInvokeError,
    ToolInvokeError,
    ToolNotFoundError,
    ToolNotSupportedError,
    ToolParameterValidationError,
    ToolProviderCredentialValidationError,
    ToolProviderNotFoundError,
)
from core.tools.tool.tool import Tool
from core.tools.tool.workflow_tool import WorkflowTool
from core.tools.utils.message_transformer import ToolFileMessageTransformer
from extensions.ext_database import db
from models.enums import CreatedByRole
from models.model import Message, MessageFile


class ToolEngine:
    """
    Tool runtime engine that takes care of tool executions.
    """

    @staticmethod
    def agent_invoke(
        tool: Tool,
        tool_parameters: Union[str, dict],
        user_id: str,
        tenant_id: str,
        message: Message,
        invoke_from: InvokeFrom,
        agent_tool_callback: DifyAgentCallbackHandler,
        trace_manager: Optional[TraceQueueManager] = None,
    ) -> tuple[str, list[tuple[MessageFile, bool]], ToolInvokeMeta]:
        """
        Run a tool on behalf of an agent and summarise the outcome as text.

        :param tool: the tool to execute
        :param tool_parameters: a dict of arguments, or a bare string that is bound to the
            tool's single LLM-form runtime parameter
        :param user_id: forwarded to the tool, the message transformer and the created files
        :param tenant_id: forwarded to the message transformer
        :param message: agent message; its ``id`` and ``conversation_id`` are read
        :param invoke_from: forwarded to ``_create_message_files``
        :param agent_tool_callback: receives ``on_tool_start`` / ``on_tool_end`` / ``on_tool_error``
        :param trace_manager: forwarded to ``on_tool_end``
        :return: ``(text, message_files, meta)``; when the invocation fails, ``text`` is an
            error description and ``message_files`` is empty
        :raises ValueError: if a string is given but the tool does not expose exactly one
            LLM-form runtime parameter
        """
        # A bare string can only be bound when there is exactly one LLM-form parameter.
        if isinstance(tool_parameters, str):
            llm_parameters = []
            for candidate in tool.get_runtime_parameters() or []:
                if candidate.form == ToolParameter.ToolParameterForm.LLM:
                    llm_parameters.append(candidate)

            if len(llm_parameters) != 1:
                raise ValueError(f"tool_parameters should be a dict, but got a string: {tool_parameters}")
            tool_parameters = {llm_parameters[0].name: tool_parameters}

        try:
            # notify the callback handler that the tool is about to run
            agent_tool_callback.on_tool_start(tool_name=tool.identity.name, tool_inputs=tool_parameters)

            meta, raw_messages = ToolEngine._invoke(tool, tool_parameters, user_id)
            transformed = ToolFileMessageTransformer.transform_tool_invoke_messages(
                messages=raw_messages,
                user_id=user_id,
                tenant_id=tenant_id,
                conversation_id=message.conversation_id,
            )

            # binary payloads are turned into message files before the text summary is built
            binary_payloads = ToolEngine._extract_tool_response_binary(transformed)
            message_files = ToolEngine._create_message_files(
                tool_messages=binary_payloads,
                agent_message=message,
                invoke_from=invoke_from,
                user_id=user_id,
            )

            # LLM friendly rendering of the tool output
            plain_text = ToolEngine._convert_tool_response_to_str(transformed)

            # notify the callback handler that the tool finished
            agent_tool_callback.on_tool_end(
                tool_name=tool.identity.name,
                tool_inputs=tool_parameters,
                tool_outputs=plain_text,
                message_id=message.id,
                trace_manager=trace_manager,
            )

            return plain_text, message_files, meta
        except ToolProviderCredentialValidationError as exc:
            failure_text = "Please check your tool provider credentials"
            agent_tool_callback.on_tool_error(exc)
        except (ToolNotFoundError, ToolNotSupportedError, ToolProviderNotFoundError) as exc:
            failure_text = f"there is not a tool named {tool.identity.name}"
            agent_tool_callback.on_tool_error(exc)
        except ToolParameterValidationError as exc:
            failure_text = f"tool parameters validation error: {exc}, please check your tool parameters"
            agent_tool_callback.on_tool_error(exc)
        except ToolInvokeError as exc:
            failure_text = f"tool invoke error: {exc}"
            agent_tool_callback.on_tool_error(exc)
        except ToolEngineInvokeError as exc:
            # the meta built by _invoke is the exception's first positional argument;
            # it is returned as-is instead of a freshly built error meta
            failed_meta = exc.args[0]
            failure_text = f"tool invoke error: {failed_meta.error}"
            agent_tool_callback.on_tool_error(exc)
            return failure_text, [], failed_meta
        except Exception as exc:
            failure_text = f"unknown error: {exc}"
            agent_tool_callback.on_tool_error(exc)

        # every handler that falls through ends up here with its error text
        return failure_text, [], ToolInvokeMeta.error_instance(failure_text)

    @staticmethod
    def workflow_invoke(
        tool: Tool,
        tool_parameters: Mapping[str, Any],
        user_id: str,
        workflow_tool_callback: DifyWorkflowCallbackHandler,
        workflow_call_depth: int,
        thread_pool_id: Optional[str] = None,
    ) -> list[ToolInvokeMessage]:
        """
        Run a tool on behalf of a workflow node.

        :param tool: the tool to execute
        :param tool_parameters: arguments supplied by the caller
        :param user_id: forwarded to ``tool.invoke``
        :param workflow_tool_callback: receives ``on_tool_start`` / ``on_tool_end`` / ``on_tool_error``
        :param workflow_call_depth: current depth; a ``WorkflowTool`` is given this value plus one
        :param thread_pool_id: assigned to the tool when it is a ``WorkflowTool``
        :return: whatever ``tool.invoke`` returns
        :raises Exception: any error is reported to ``on_tool_error`` and then re-raised
        """
        try:
            assert tool.identity is not None
            # on_tool_start sees the parameters exactly as the caller passed them
            workflow_tool_callback.on_tool_start(tool_name=tool.identity.name, tool_inputs=tool_parameters)

            # WorkflowTool instances receive the incremented call depth and the thread pool id
            if isinstance(tool, WorkflowTool):
                tool.workflow_call_depth = workflow_call_depth + 1
                tool.thread_pool_id = thread_pool_id

            runtime = tool.runtime
            if runtime and runtime.runtime_parameters:
                # runtime parameters form the base; caller-supplied values override them
                merged_parameters = dict(runtime.runtime_parameters)
                merged_parameters.update(tool_parameters)
                tool_parameters = merged_parameters

            tool_messages = tool.invoke(user_id=user_id, tool_parameters=tool_parameters)

            # on_tool_end sees the merged parameters that were actually used
            workflow_tool_callback.on_tool_end(
                tool_name=tool.identity.name,
                tool_inputs=tool_parameters,
                tool_outputs=tool_messages,
            )

            return tool_messages
        except Exception as exc:
            workflow_tool_callback.on_tool_error(exc)
            raise exc

    @staticmethod
    def _invoke(tool: Tool, tool_parameters: dict, user_id: str) -> tuple[ToolInvokeMeta, list[ToolInvokeMessage]]:
        """
        Invoke the tool and collect metadata (timing, configuration, error) about the call.

        :param tool: the tool to execute
        :param tool_parameters: arguments passed to ``tool.invoke``
        :param user_id: passed to ``tool.invoke``
        :return: ``(meta, response)`` where ``response`` is what ``tool.invoke`` returned
        :raises ToolEngineInvokeError: wraps any exception raised by ``tool.invoke``; it is
            constructed with ``meta`` (``meta.error`` holds ``str`` of the original exception)
            and is explicitly chained to that exception
        """
        started_at = datetime.now(timezone.utc)
        meta = ToolInvokeMeta(
            time_cost=0.0,
            error=None,
            tool_config={
                "tool_name": tool.identity.name,
                "tool_provider": tool.identity.provider,
                "tool_provider_type": tool.tool_provider_type().value,
                # snapshot the runtime parameters so later mutation does not change the recorded config
                "tool_parameters": deepcopy(tool.runtime.runtime_parameters),
                "tool_icon": tool.identity.icon,
            },
        )
        try:
            response = tool.invoke(user_id, tool_parameters)
        except Exception as e:
            meta.error = str(e)
            # chain explicitly so the original traceback is kept as __cause__
            raise ToolEngineInvokeError(meta) from e
        finally:
            # runs on both paths; the exception carries this same meta object,
            # so the time cost is visible to whoever catches it
            ended_at = datetime.now(timezone.utc)
            meta.time_cost = (ended_at - started_at).total_seconds()

        return meta, response

    @staticmethod
    def _convert_tool_response_to_str(tool_response: list[ToolInvokeMessage]) -> str:
        """
        Flatten tool invoke messages into one string for the LLM.

        Text is appended verbatim, links and images become short instructions,
        JSON messages are serialised with ``json.dumps`` and every other message
        type is rendered through its ``message`` attribute.

        :param tool_response: messages produced by the tool
        :return: the concatenation of all rendered fragments (empty string for no messages)
        """
        fragments = []
        for item in tool_response:
            kind = item.type
            if kind == ToolInvokeMessage.MessageType.TEXT:
                fragments.append(item.message)
            elif kind == ToolInvokeMessage.MessageType.LINK:
                fragments.append(f"result link: {item.message}. please tell user to check it.")
            elif kind in {ToolInvokeMessage.MessageType.IMAGE_LINK, ToolInvokeMessage.MessageType.IMAGE}:
                # the image itself is not embedded in the text
                fragments.append(
                    "image has been created and sent to user already, you do not need to create it,"
                    " just tell the user to check it now."
                )
            elif kind == ToolInvokeMessage.MessageType.JSON:
                serialized = json.dumps(item.message, ensure_ascii=False)
                fragments.append(f"tool response: {serialized}.")
            else:
                fragments.append(f"tool response: {item.message}.")

        return "".join(fragments)

    @staticmethod
    def _extract_tool_response_binary(tool_response: list[ToolInvokeMessage]) -> list[ToolInvokeMessageBinary]:
        """
        Collect the file-like payloads (images, blobs, typed links) from a tool response.

        - IMAGE / IMAGE_LINK: the mime type comes from ``meta["mime_type"]`` when set,
          otherwise it is guessed from the URL's file extension, otherwise ``image/jpeg``.
        - BLOB: the mime type comes from ``meta["mime_type"]``, defaulting to ``octet/stream``.
        - LINK: only included when ``meta`` contains a ``mime_type`` key.

        Other message types are ignored.

        :param tool_response: messages produced by the tool
        :return: one ``ToolInvokeMessageBinary`` per extracted payload, in input order
        """
        result = []

        for response in tool_response:
            # ``meta`` may be None or empty; normalise it once so every branch can use ``.get``
            meta = response.meta or {}

            if response.type in {ToolInvokeMessage.MessageType.IMAGE_LINK, ToolInvokeMessage.MessageType.IMAGE}:
                mimetype = meta.get("mime_type")
                if not mimetype:
                    try:
                        # no declared mime type: guess from the URL's file extension
                        extension = URL(response.message).suffix
                        mimetype, _ = guess_type(f"a{extension}")
                    except Exception:
                        # deliberate best effort: an unparsable message simply yields no guess
                        mimetype = None

                result.append(
                    ToolInvokeMessageBinary(
                        # use the resolved value (previously computed but then discarded)
                        mimetype=mimetype or "image/jpeg",
                        url=response.message,
                        save_as=response.save_as,
                    )
                )
            elif response.type == ToolInvokeMessage.MessageType.BLOB:
                result.append(
                    ToolInvokeMessageBinary(
                        mimetype=meta.get("mime_type", "octet/stream"),
                        url=response.message,
                        save_as=response.save_as,
                    )
                )
            elif response.type == ToolInvokeMessage.MessageType.LINK:
                # a plain link is only treated as a file when a mime type was declared for it
                if "mime_type" in meta:
                    result.append(
                        ToolInvokeMessageBinary(
                            mimetype=meta["mime_type"],
                            url=response.message,
                            save_as=response.save_as,
                        )
                    )

        return result

    @staticmethod
    def _create_message_files(
        tool_messages: list[ToolInvokeMessageBinary],
        agent_message: Message,
        invoke_from: InvokeFrom,
        user_id: str,
    ) -> list[tuple[Any, str]]:
        """
        Persist one ``MessageFile`` row per binary tool message.

        :param tool_messages: binary messages extracted from the tool response
        :param agent_message: message whose ``id`` the files are attached to
        :param invoke_from: EXPLORE / DEBUGGER mark the creator as ACCOUNT, anything else as END_USER
        :param user_id: stored as ``created_by``
        :return: list of ``(message_file_id, save_as)`` tuples, in input order
        """
        # (substrings looked for in the mimetype, resulting file type); the first matching rule wins
        mimetype_rules = (
            (("image",), FileType.IMAGE),
            (("video",), FileType.VIDEO),
            (("audio",), FileType.AUDIO),
            (("text", "pdf"), FileType.DOCUMENT),
        )
        created = []

        for tool_message in tool_messages:
            mimetype = tool_message.mimetype
            file_type = next(
                (
                    matched_type
                    for needles, matched_type in mimetype_rules
                    if any(needle in mimetype for needle in needles)
                ),
                FileType.CUSTOM,
            )

            if invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER}:
                creator_role = CreatedByRole.ACCOUNT
            else:
                creator_role = CreatedByRole.END_USER

            # tool file id: last path segment of the URL, cut at its first dot
            last_segment = tool_message.url.split("/")[-1]
            tool_file_id = last_segment.split(".")[0]

            record = MessageFile(
                message_id=agent_message.id,
                type=file_type,
                transfer_method=FileTransferMethod.TOOL_FILE,
                belongs_to="assistant",
                url=tool_message.url,
                upload_file_id=tool_file_id,
                created_by_role=creator_role,
                created_by=user_id,
            )

            # each file is committed on its own and refreshed before its id is read
            db.session.add(record)
            db.session.commit()
            db.session.refresh(record)

            created.append((record.id, tool_message.save_as))

        db.session.close()

        return created