File size: 6,138 Bytes
8437908 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 |
import inspect
import logging
from typing import Awaitable, Callable, get_type_hints
from open_webui.apps.webui.models.tools import Tools
from open_webui.apps.webui.models.users import UserModel
from open_webui.apps.webui.utils import load_toolkit_module_by_id
from open_webui.utils.schemas import json_schema_to_model
log = logging.getLogger(__name__)
def apply_extra_params_to_tool_function(
    function: Callable, extra_params: dict
) -> Callable[..., Awaitable]:
    """Wrap ``function`` in a coroutine function that injects ``extra_params``.

    Only the entries of ``extra_params`` whose key names a parameter of
    ``function`` are injected. That filtering happens once, when the wrapper
    is built, so keys later added to or rebound in ``extra_params`` do not
    affect it. Nested mutable values are still shared by reference.

    Injected values take precedence over same-named keyword arguments passed
    to the wrapper, so callers cannot override them.

    The returned wrapper must always be awaited. If ``function`` returns an
    awaitable, that awaitable is awaited and its result returned. This covers
    ``async def`` functions and also sync callables that return a coroutine,
    such as an async tool wrapped by a plain sync decorator.
    """
    accepted = inspect.signature(function).parameters
    injected = {key: value for key, value in extra_params.items() if key in accepted}

    async def new_function(**kwargs):
        # Right-hand side wins: injected params override caller kwargs.
        result = function(**(kwargs | injected))
        if inspect.isawaitable(result):
            return await result
        return result

    return new_function
# NOTE: get_tools treats ``extra_params`` as read-only; each toolkit gets its own copy.
def get_tools(
    webui_app, tool_ids: list[str], user: UserModel, extra_params: dict
) -> dict[str, dict]:
    """Load the given toolkits and expose each spec'd function as a tool.

    For every id in ``tool_ids`` whose toolkit exists, this function:

    - takes the toolkit module from ``webui_app.state.TOOLS``, or loads it
      into that cache;
    - applies the toolkit's admin valves and per-user valves;
    - wraps each function listed in the toolkit's ``specs`` with
      ``apply_extra_params_to_tool_function``.

    Args:
        webui_app: App whose ``state.TOOLS`` dict caches loaded modules.
        tool_ids: Toolkit ids to load. Ids with no stored toolkit are skipped.
        user: Current user. Its ``id`` selects the per-user valves.
        extra_params: Reserved parameters (e.g. ``__user__``) offered to tool
            functions. This dict is not mutated. Each toolkit receives a
            shallow copy with ``__id__`` set to the toolkit id. If the module
            defines ``UserValves``, the copy also gets its own ``__user__``
            dict with a ``valves`` entry; if ``__user__`` is absent or falsy,
            a new ``{"valves": ...}`` dict is used.

    Returns:
        A mapping from function name to a dict with the keys ``toolkit_id``,
        ``callable``, ``spec``, ``pydantic_model``, ``file_handler`` and
        ``citation``. When two toolkits define the same function name, the
        first one wins; the later one is logged and discarded.

    Note:
        Spec properties typed ``"str"`` are rewritten to ``"string"`` in place
        on ``toolkit.specs``.
    """
    tools = {}
    for tool_id in tool_ids:
        toolkit = Tools.get_tool_by_id(tool_id)
        if toolkit is None:
            continue

        module = webui_app.state.TOOLS.get(tool_id, None)
        if module is None:
            module, _ = load_toolkit_module_by_id(tool_id)
            webui_app.state.TOOLS[tool_id] = module

        # Per-toolkit copy. Writing into the caller's dict, and into its shared
        # nested ``__user__`` dict, made every wrapper see the valves of the
        # last toolkit that defined UserValves, and leaked state to the caller.
        tool_params = {**extra_params, "__id__": tool_id}

        if hasattr(module, "valves") and hasattr(module, "Valves"):
            valves = Tools.get_tool_valves_by_id(tool_id) or {}
            module.valves = module.Valves(**valves)

        if hasattr(module, "UserValves"):
            user_valves_data = (
                Tools.get_user_valves_by_id_and_user_id(tool_id, user.id) or {}
            )
            tool_params["__user__"] = {
                **(extra_params.get("__user__") or {}),
                "valves": module.UserValves(**user_valves_data),  # type: ignore
            }

        for spec in toolkit.specs:
            # TODO: Fix hack for OpenAI API
            for prop in spec.get("parameters", {}).get("properties", {}).values():
                if prop.get("type") == "str":
                    prop["type"] = "string"

            function_name = spec["name"]

            # TODO: if collision, prepend toolkit name
            if function_name in tools:
                existing_id = tools[function_name]["toolkit_id"]
                log.warning(f"Tool {function_name} already exists in another toolkit!")
                log.warning(f"Collision between {existing_id} and {tool_id}.")
                log.warning(f"Discarding {tool_id}.{function_name}")
                continue

            # The wrapper takes only the model's arguments and injects the
            # reserved params from tool_params at call time.
            original_func = getattr(module, function_name)
            tool_callable = apply_extra_params_to_tool_function(
                original_func, tool_params
            )
            tool_callable.__doc__ = getattr(original_func, "__doc__", None)

            # TODO: This needs to be a pydantic model
            tools[function_name] = {
                "toolkit_id": tool_id,
                "callable": tool_callable,
                "spec": spec,
                "pydantic_model": json_schema_to_model(spec),
                "file_handler": getattr(module, "file_handler", False),
                "citation": getattr(module, "citation", False),
            }
    return tools
def doc_to_dict(docstring):
    """Parse a reST-style docstring into a description and per-parameter docs.

    The description is the first non-blank line that is not a ``:field:``
    line. This works for one-line docstrings and for bare names, where
    indexing a fixed second line would raise ``IndexError``.

    Lines beginning with ``:param`` are read as ``:param <name>: <text>``; a
    missing ``: <text>`` part yields an empty description. Other field lines,
    such as ``:return:``, are ignored.

    Returns:
        ``{"description": str, "params": {param_name: text}}``.
    """
    description = ""
    params = {}
    for raw_line in docstring.splitlines():
        line = raw_line.strip()
        if not line:
            continue
        if line.startswith(":param"):
            name, _, text = line[len(":param"):].partition(":")
            params[name.strip()] = text.strip()
        elif not description and not line.startswith(":"):
            description = line
    return {"description": description, "params": params}
def _is_dunder_name(name: str) -> bool:
    """Return True for ``__reserved__``-style names, which are hidden from specs."""
    return name.startswith("__") and name.endswith("__")


def get_tools_specs(tools) -> list[dict]:
    """Build function-calling specs from the methods of a toolkit instance.

    Every callable, non-class attribute of ``tools`` whose name does not start
    with ``"__"`` yields one spec with these parts:

    - ``name``: the attribute name.
    - ``description``: taken from the docstring via ``doc_to_dict``. It falls
      back to the function name when there is no docstring or the parsed
      description is empty.
    - ``parameters``: an object with ``type``, ``properties`` (built from the
      type hints) and ``required`` (parameters without a default).

    Parameters named like ``__this__`` are omitted from both ``properties``
    and ``required``.
    """
    specs = []
    for function_name in dir(tools):
        function = getattr(tools, function_name)
        if (
            not callable(function)
            or function_name.startswith("__")
            or inspect.isclass(function)
        ):
            continue

        # Only parse real docstrings. A bare function name is a one-line
        # string, which doc_to_dict's fixed second-line lookup cannot handle.
        docstring = function.__doc__
        function_doc = doc_to_dict(docstring) if docstring else {}
        param_docs = function_doc.get("params", {})

        properties = {}
        for param_name, annotation in get_type_hints(function).items():
            if param_name == "return" or _is_dunder_name(param_name):
                continue
            # NOTE(review): ``__name__.lower()`` yields Python type names
            # ("str", "int"), and get_tools later maps "str" -> "string".
            # Generic aliases (Optional[...], etc.) may lack ``__name__`` on
            # older Pythons; confirm which annotation types are supported.
            prop = {"type": annotation.__name__.lower()}
            if hasattr(annotation, "__args__"):
                # NOTE(review): stored as the tuple's repr string, not a JSON
                # list; kept as-is for compatibility with existing specs.
                prop["enum"] = str(annotation.__args__)
            prop["description"] = param_docs.get(param_name, param_name)
            properties[param_name] = prop

        required = [
            name
            for name, param in inspect.signature(function).parameters.items()
            if param.default is param.empty and not _is_dunder_name(name)
        ]

        specs.append(
            {
                "name": function_name,
                # TODO: multi-line desc?
                "description": function_doc.get("description") or function_name,
                "parameters": {
                    "type": "object",
                    "properties": properties,
                    "required": required,
                },
            }
        )
    return specs
|