Spaces:
Running
Running
from pathlib import Path | |
from typing import Optional | |
from open_webui.models.tools import ( | |
ToolForm, | |
ToolModel, | |
ToolResponse, | |
ToolUserResponse, | |
Tools, | |
) | |
from open_webui.utils.plugin import load_tools_module_by_id, replace_imports | |
from open_webui.config import CACHE_DIR | |
from open_webui.constants import ERROR_MESSAGES | |
from fastapi import APIRouter, Depends, HTTPException, Request, status | |
from open_webui.utils.tools import get_tools_specs | |
from open_webui.utils.auth import get_admin_user, get_verified_user | |
from open_webui.utils.access_control import has_access, has_permission | |
# Router meant to hold the tool-management endpoints defined in this module.
# NOTE(review): in the visible source none of the handlers below carry a
# `@router.<method>(...)` decorator, so nothing is registered on this router —
# confirm the route decorators were not lost.
router = APIRouter()
############################
# GetTools
############################


async def get_tools(user=Depends(get_verified_user)):
    """Return the tools ``user`` may read.

    Admins get every tool via ``Tools.get_tools()``; any other user gets the
    result of ``Tools.get_tools_by_user_id(user.id, "read")``.
    """
    if user.role != "admin":
        return Tools.get_tools_by_user_id(user.id, "read")
    return Tools.get_tools()
############################
# GetToolList
############################


async def get_tool_list(user=Depends(get_verified_user)):
    """Return the tools ``user`` may edit.

    Admins get every tool via ``Tools.get_tools()``; any other user gets the
    result of ``Tools.get_tools_by_user_id(user.id, "write")``.
    """
    visible = (
        Tools.get_tools()
        if user.role == "admin"
        else Tools.get_tools_by_user_id(user.id, "write")
    )
    return visible
############################
# ExportTools
############################


async def export_tools(user=Depends(get_admin_user)):
    """Return every stored tool.

    Access is restricted by the ``get_admin_user`` dependency.
    """
    return Tools.get_tools()
############################
# CreateNewTools
############################


async def create_new_tools(
    request: Request,
    form_data: ToolForm,
    user=Depends(get_verified_user),
):
    """Create a new tool from submitted source code.

    The id must be a valid Python identifier and is lower-cased before use.
    The content is rewritten with ``replace_imports``, loaded as a module
    (whose frontmatter becomes ``form_data.meta.manifest``), cached in
    ``request.app.state.TOOLS``, and its specs are computed with
    ``get_tools_specs`` before the record is inserted. A cache directory
    ``CACHE_DIR/tools/<id>`` is created for the tool.

    Raises:
        HTTPException 401: a non-admin user lacks the ``workspace.tools``
            permission.
        HTTPException 400: the id is invalid or already taken, loading or
            storing the tool failed, or the insert returned nothing.
    """
    # Creating a tool is gated by the *tools* workspace permission; checking
    # "workspace.knowledge" here did not match this endpoint.
    if user.role != "admin" and not has_permission(
        user.id, "workspace.tools", request.app.state.config.USER_PERMISSIONS
    ):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=ERROR_MESSAGES.UNAUTHORIZED,
        )

    if not form_data.id.isidentifier():
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Only alphanumeric characters and underscores are allowed in the id",
        )

    form_data.id = form_data.id.lower()

    if Tools.get_tool_by_id(form_data.id) is not None:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.ID_TAKEN,
        )

    try:
        form_data.content = replace_imports(form_data.content)
        tools_module, frontmatter = load_tools_module_by_id(
            form_data.id, content=form_data.content
        )
        form_data.meta.manifest = frontmatter

        TOOLS = request.app.state.TOOLS
        TOOLS[form_data.id] = tools_module

        specs = get_tools_specs(TOOLS[form_data.id])
        tools = Tools.insert_new_tool(user.id, form_data, specs)

        tool_cache_dir = Path(CACHE_DIR) / "tools" / form_data.id
        tool_cache_dir.mkdir(parents=True, exist_ok=True)
    except Exception as e:
        print(e)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT(str(e)),
        ) from e

    # Kept outside the try: raising inside it caused the broad `except` to
    # re-wrap this HTTPException into a garbled "400: ..." detail message.
    if tools:
        return tools
    raise HTTPException(
        status_code=status.HTTP_400_BAD_REQUEST,
        detail=ERROR_MESSAGES.DEFAULT("Error creating tools"),
    )
############################
# GetToolsById
############################


async def get_tools_by_id(id: str, user=Depends(get_verified_user)):
    """Return a single tool if ``user`` may read it.

    Read access is granted to admins, to the tool's owner, and to users for
    whom ``has_access(user.id, "read", tool.access_control)`` is true.

    Raises:
        HTTPException 401 (NOT_FOUND): no tool exists with this id.
        HTTPException 401 (UNAUTHORIZED): the tool exists but ``user`` has no
            read access.
    """
    tools = Tools.get_tool_by_id(id)
    if not tools:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=ERROR_MESSAGES.NOT_FOUND,
        )

    if (
        user.role == "admin"
        or tools.user_id == user.id
        or has_access(user.id, "read", tools.access_control)
    ):
        return tools

    # Every failure path must raise; previously one of them fell through and
    # implicitly returned None (a successful empty response).
    raise HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail=ERROR_MESSAGES.UNAUTHORIZED,
    )
############################
# UpdateToolsById
############################


async def update_tools_by_id(
    request: Request,
    id: str,
    form_data: ToolForm,
    user=Depends(get_verified_user),
):
    """Replace the source and metadata of an existing tool.

    Only the tool's owner or an admin may update it. The new content is
    rewritten with ``replace_imports``, reloaded as a module (whose
    frontmatter becomes ``form_data.meta.manifest``), cached in
    ``request.app.state.TOOLS``, and stored together with freshly computed
    specs. The ``id`` field of ``form_data`` is ignored.

    Raises:
        HTTPException 401 (NOT_FOUND): no tool exists with this id.
        HTTPException 401 (UNAUTHORIZED): ``user`` is neither owner nor admin.
        HTTPException 400: loading or storing failed, or the update returned
            nothing.
    """
    tools = Tools.get_tool_by_id(id)
    if not tools:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=ERROR_MESSAGES.NOT_FOUND,
        )

    # NOTE(review): get_tool_list offers non-admins the tools they have "write"
    # access to, but only the owner or an admin passes this check — confirm
    # whether write-access users should be allowed to update as well.
    if tools.user_id != user.id and user.role != "admin":
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=ERROR_MESSAGES.UNAUTHORIZED,
        )

    try:
        form_data.content = replace_imports(form_data.content)
        tools_module, frontmatter = load_tools_module_by_id(
            id, content=form_data.content
        )
        form_data.meta.manifest = frontmatter

        TOOLS = request.app.state.TOOLS
        TOOLS[id] = tools_module

        specs = get_tools_specs(TOOLS[id])
        updated = {
            **form_data.model_dump(exclude={"id"}),
            "specs": specs,
        }
        # (A debug print of the whole payload, including source, was removed.)
        tools = Tools.update_tool_by_id(id, updated)
    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT(str(e)),
        ) from e

    # Kept outside the try so the broad `except` does not re-wrap it.
    if tools:
        return tools
    raise HTTPException(
        status_code=status.HTTP_400_BAD_REQUEST,
        detail=ERROR_MESSAGES.DEFAULT("Error updating tools"),
    )
############################
# DeleteToolsById
############################


async def delete_tools_by_id(
    request: Request, id: str, user=Depends(get_verified_user)
):
    """Delete a tool; allowed for its owner or an admin.

    When ``Tools.delete_tool_by_id`` reports success, the tool's loaded module
    is also dropped from ``request.app.state.TOOLS``. The deletion result is
    returned as-is.

    Raises:
        HTTPException 401 (NOT_FOUND): no tool exists with this id.
        HTTPException 401 (UNAUTHORIZED): ``user`` is neither owner nor admin.
    """
    existing = Tools.get_tool_by_id(id)
    if not existing:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=ERROR_MESSAGES.NOT_FOUND,
        )

    is_owner = existing.user_id == user.id
    if not (is_owner or user.role == "admin"):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=ERROR_MESSAGES.UNAUTHORIZED,
        )

    deleted = Tools.delete_tool_by_id(id)
    if deleted:
        # Evict the cached module, if one was loaded.
        request.app.state.TOOLS.pop(id, None)
    return deleted
############################
# GetToolValves
############################


async def get_tools_valves_by_id(id: str, user=Depends(get_verified_user)):
    """Return the stored tool-level valve values for tool ``id``.

    Raises:
        HTTPException 401 (NOT_FOUND): no tool exists with this id.
        HTTPException 400: ``Tools.get_tool_valves_by_id`` raised.
    """
    # NOTE(review): no ownership/admin check here — any verified user can read
    # the tool-level valves; confirm that is intended if valves hold secrets.
    if not Tools.get_tool_by_id(id):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=ERROR_MESSAGES.NOT_FOUND,
        )

    try:
        return Tools.get_tool_valves_by_id(id)
    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT(str(e)),
        )
############################
# GetToolValvesSpec
############################


async def get_tools_valves_spec_by_id(
    request: Request, id: str, user=Depends(get_verified_user)
):
    """Return the JSON schema of tool ``id``'s ``Valves`` model, or None.

    The tool module is taken from ``request.app.state.TOOLS``; if it is not
    loaded yet it is loaded with ``load_tools_module_by_id`` and cached there.
    Returns None when the module defines no ``Valves`` class.

    Raises:
        HTTPException 401 (NOT_FOUND): no tool exists with this id.
    """
    if not Tools.get_tool_by_id(id):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=ERROR_MESSAGES.NOT_FOUND,
        )

    TOOLS = request.app.state.TOOLS
    if id in TOOLS:
        tools_module = TOOLS[id]
    else:
        tools_module, _ = load_tools_module_by_id(id)
        TOOLS[id] = tools_module

    if not hasattr(tools_module, "Valves"):
        return None

    # `.schema()` is the deprecated Pydantic v1 spelling (it emits a warning
    # under v2); valves are already handled with the v2 `model_dump()` API in
    # this module, and `model_json_schema()` yields the same schema.
    return tools_module.Valves.model_json_schema()
############################
# UpdateToolValves
############################


async def update_tools_valves_by_id(
    request: Request, id: str, form_data: dict, user=Depends(get_verified_user)
):
    """Validate and store new tool-level valve values for tool ``id``.

    ``None`` entries in ``form_data`` are dropped; the rest are validated by
    the module's ``Valves`` model, saved with ``Tools.update_tool_valves_by_id``
    and returned as a dict.

    Raises:
        HTTPException 401 (NOT_FOUND): no tool with this id, or its module
            defines no ``Valves`` class.
        HTTPException 401 (UNAUTHORIZED): ``user`` is neither owner nor admin.
        HTTPException 400: validation or storage failed.
    """
    tools = Tools.get_tool_by_id(id)
    if not tools:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=ERROR_MESSAGES.NOT_FOUND,
        )

    # These valves are stored per tool (not per user, unlike the user valves),
    # so apply the same owner-or-admin rule as updating or deleting the tool.
    # Previously any verified user could overwrite them.
    if tools.user_id != user.id and user.role != "admin":
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=ERROR_MESSAGES.UNAUTHORIZED,
        )

    TOOLS = request.app.state.TOOLS
    if id in TOOLS:
        tools_module = TOOLS[id]
    else:
        tools_module, _ = load_tools_module_by_id(id)
        TOOLS[id] = tools_module

    if not hasattr(tools_module, "Valves"):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=ERROR_MESSAGES.NOT_FOUND,
        )
    Valves = tools_module.Valves

    try:
        form_data = {k: v for k, v in form_data.items() if v is not None}
        valves = Valves(**form_data)
        Tools.update_tool_valves_by_id(id, valves.model_dump())
        return valves.model_dump()
    except Exception as e:
        print(e)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT(str(e)),
        ) from e
############################
# ToolUserValves
############################


async def get_tools_user_valves_by_id(id: str, user=Depends(get_verified_user)):
    """Return ``user``'s own valve values for tool ``id``.

    Raises:
        HTTPException 401 (NOT_FOUND): no tool exists with this id.
        HTTPException 400: ``Tools.get_user_valves_by_id_and_user_id`` raised.
    """
    if not Tools.get_tool_by_id(id):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=ERROR_MESSAGES.NOT_FOUND,
        )

    try:
        return Tools.get_user_valves_by_id_and_user_id(id, user.id)
    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT(str(e)),
        )
async def get_tools_user_valves_spec_by_id(
    request: Request, id: str, user=Depends(get_verified_user)
):
    """Return the JSON schema of tool ``id``'s ``UserValves`` model, or None.

    The tool module is taken from ``request.app.state.TOOLS``; if it is not
    loaded yet it is loaded with ``load_tools_module_by_id`` and cached there.
    Returns None when the module defines no ``UserValves`` class.

    Raises:
        HTTPException 401 (NOT_FOUND): no tool exists with this id.
    """
    if not Tools.get_tool_by_id(id):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=ERROR_MESSAGES.NOT_FOUND,
        )

    TOOLS = request.app.state.TOOLS
    if id in TOOLS:
        tools_module = TOOLS[id]
    else:
        tools_module, _ = load_tools_module_by_id(id)
        TOOLS[id] = tools_module

    if not hasattr(tools_module, "UserValves"):
        return None

    # `.schema()` is the deprecated Pydantic v1 spelling (it emits a warning
    # under v2); user valves are already handled with the v2 `model_dump()`
    # API in this module, and `model_json_schema()` yields the same schema.
    return tools_module.UserValves.model_json_schema()
async def update_tools_user_valves_by_id(
    request: Request, id: str, form_data: dict, user=Depends(get_verified_user)
):
    """Validate and store ``user``'s own valve values for tool ``id``.

    ``None`` entries in ``form_data`` are dropped; the rest are validated by
    the module's ``UserValves`` model, saved with
    ``Tools.update_user_valves_by_id_and_user_id`` and returned as a dict.
    The module is loaded and cached in ``request.app.state.TOOLS`` if needed.

    Raises:
        HTTPException 401 (NOT_FOUND): no tool with this id, or its module
            defines no ``UserValves`` class.
        HTTPException 400: validation or storage failed.
    """
    if not Tools.get_tool_by_id(id):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=ERROR_MESSAGES.NOT_FOUND,
        )

    loaded_modules = request.app.state.TOOLS
    if id in loaded_modules:
        tools_module = loaded_modules[id]
    else:
        tools_module, _ = load_tools_module_by_id(id)
        loaded_modules[id] = tools_module

    if not hasattr(tools_module, "UserValves"):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=ERROR_MESSAGES.NOT_FOUND,
        )

    UserValves = tools_module.UserValves
    try:
        provided = {key: value for key, value in form_data.items() if value is not None}
        user_valves = UserValves(**provided)
        Tools.update_user_valves_by_id_and_user_id(
            id, user.id, user_valves.model_dump()
        )
        return user_valves.model_dump()
    except Exception as e:
        print(e)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT(str(e)),
        )