Timothy Jaeryang Baek 1 týždeň pred
rodič
commit
54beeeaf72

+ 25 - 9
backend/open_webui/routers/tools.py

@@ -17,7 +17,11 @@ from open_webui.models.tools import (
     ToolUserResponse,
     Tools,
 )
-from open_webui.utils.plugin import load_tool_module_by_id, replace_imports
+from open_webui.utils.plugin import (
+    load_tool_module_by_id,
+    replace_imports,
+    get_tool_module_from_cache,
+)
 from open_webui.utils.tools import get_tool_specs
 from open_webui.utils.auth import get_admin_user, get_verified_user
 from open_webui.utils.access_control import has_access, has_permission
@@ -35,6 +39,14 @@ log.setLevel(SRC_LOG_LEVELS["MAIN"])
 router = APIRouter()
 
 
+def get_tool_module(request, tool_id, load_from_db=True):
+    """
+    Get the tool module by its ID.
+    """
+    tool_module, _ = get_tool_module_from_cache(request, tool_id, load_from_db)
+    return tool_module
+
+
 ############################
 # GetTools
 ############################
@@ -42,15 +54,19 @@ router = APIRouter()
 
 @router.get("/", response_model=list[ToolUserResponse])
 async def get_tools(request: Request, user=Depends(get_verified_user)):
-    tools = [
-        ToolUserResponse(
-            **{
-                **tool.model_dump(),
-                "has_user_valves": "class UserValves(BaseModel):" in tool.content,
-            }
+    tools = []
+
+    # Local Tools
+    for tool in Tools.get_tools():
+        tool_module = get_tool_module(request, tool.id)
+        tools.append(
+            ToolUserResponse(
+                **{
+                    **tool.model_dump(),
+                    "has_user_valves": hasattr(tool_module, "UserValves"),
+                }
+            )
         )
-        for tool in Tools.get_tools()
-    ]
 
     # OpenAPI Tool Servers
     for server in await get_tool_servers(request):

+ 42 - 0
backend/open_webui/utils/plugin.py

@@ -166,6 +166,48 @@ def load_function_module_by_id(function_id: str, content: str | None = None):
         os.unlink(temp_file.name)
 
 
+def get_tool_module_from_cache(request, tool_id, load_from_db=True):
+    if load_from_db:
+        # Always load from the database by default
+        tool = Tools.get_tool_by_id(tool_id)
+        if not tool:
+            raise Exception(f"Tool not found: {tool_id}")
+        content = tool.content
+
+        new_content = replace_imports(content)
+        if new_content != content:
+            content = new_content
+            # Update the tool content in the database
+            Tools.update_tool_by_id(tool_id, {"content": content})
+
+        if (
+            hasattr(request.app.state, "TOOL_CONTENTS")
+            and tool_id in request.app.state.TOOL_CONTENTS
+        ) and (
+            hasattr(request.app.state, "TOOLS") and tool_id in request.app.state.TOOLS
+        ):
+            if request.app.state.TOOL_CONTENTS[tool_id] == content:
+                return request.app.state.TOOLS[tool_id], None
+
+        tool_module, frontmatter = load_tool_module_by_id(tool_id, content)
+    else:
+        if hasattr(request.app.state, "TOOLS") and tool_id in request.app.state.TOOLS:
+            return request.app.state.TOOLS[tool_id], None
+
+        tool_module, frontmatter = load_tool_module_by_id(tool_id)
+
+    if not hasattr(request.app.state, "TOOLS"):
+        request.app.state.TOOLS = {}
+
+    if not hasattr(request.app.state, "TOOL_CONTENTS"):
+        request.app.state.TOOL_CONTENTS = {}
+
+    request.app.state.TOOLS[tool_id] = tool_module
+    request.app.state.TOOL_CONTENTS[tool_id] = content
+
+    return tool_module, frontmatter
+
+
 def get_function_module_from_cache(request, function_id, load_from_db=True):
     if load_from_db:
         # Always load from the database by default