Переглянути джерело

refac: tool server redis cache

Timothy Jaeryang Baek 1 місяць тому
батько
коміт
f592748011

+ 2 - 5
backend/open_webui/routers/configs.py

@@ -9,8 +9,8 @@ from open_webui.config import BannerModel
 
 from open_webui.utils.tools import (
     get_tool_server_data,
-    get_tool_servers_data,
     get_tool_server_url,
+    set_tool_servers,
 )
 
 
@@ -114,10 +114,7 @@ async def set_tool_servers_config(
     request.app.state.config.TOOL_SERVER_CONNECTIONS = [
         connection.model_dump() for connection in form_data.TOOL_SERVER_CONNECTIONS
     ]
-
-    request.app.state.TOOL_SERVERS = await get_tool_servers_data(
-        request.app.state.config.TOOL_SERVER_CONNECTIONS
-    )
+    await set_tool_servers(request)
 
     return {
         "TOOL_SERVER_CONNECTIONS": request.app.state.config.TOOL_SERVER_CONNECTIONS,

+ 4 - 12
backend/open_webui/routers/tools.py

@@ -19,7 +19,7 @@ from open_webui.utils.plugin import load_tool_module_by_id, replace_imports
 from open_webui.utils.tools import get_tool_specs
 from open_webui.utils.auth import get_admin_user, get_verified_user
 from open_webui.utils.access_control import has_access, has_permission
-from open_webui.utils.tools import get_tool_servers_data
+from open_webui.utils.tools import get_tool_servers
 
 from open_webui.env import SRC_LOG_LEVELS
 from open_webui.config import CACHE_DIR, ENABLE_ADMIN_WORKSPACE_CONTENT_ACCESS
@@ -32,6 +32,7 @@ log.setLevel(SRC_LOG_LEVELS["MAIN"])
 
 router = APIRouter()
 
+
 ############################
 # GetTools
 ############################
@@ -39,18 +40,9 @@ router = APIRouter()
 
 @router.get("/", response_model=list[ToolUserResponse])
 async def get_tools(request: Request, user=Depends(get_verified_user)):
-
-    if not request.app.state.TOOL_SERVERS:
-        # If the tool servers are not set, we need to set them
-        # This is done only once when the server starts
-        # This is done to avoid loading the tool servers every time
-
-        request.app.state.TOOL_SERVERS = await get_tool_servers_data(
-            request.app.state.config.TOOL_SERVER_CONNECTIONS
-        )
-
     tools = Tools.get_tools()
-    for server in request.app.state.TOOL_SERVERS:
+
+    for server in await get_tool_servers(request):
         tools.append(
             ToolUserResponse(
                 **{

+ 1 - 1
backend/open_webui/utils/middleware.py

@@ -910,7 +910,7 @@ async def process_chat_payload(request, form_data, user, metadata, model):
     tools_dict = {}
 
     if tool_ids:
-        tools_dict = get_tools(
+        tools_dict = await get_tools(
             request,
             tool_ids,
             user,

+ 27 - 2
backend/open_webui/utils/tools.py

@@ -68,7 +68,7 @@ def get_async_tool_function_and_apply_extra_params(
         return new_function
 
 
-def get_tools(
+async def get_tools(
     request: Request, tool_ids: list[str], user: UserModel, extra_params: dict
 ) -> dict[str, dict]:
     tools_dict = {}
@@ -80,7 +80,7 @@ def get_tools(
                 server_id = tool_id.split(":")[1]
 
                 tool_server_data = None
-                for server in request.app.state.TOOL_SERVERS:
+                for server in await get_tool_servers(request):
                     if server["id"] == server_id:
                         tool_server_data = server
                         break
@@ -447,6 +447,31 @@ def convert_openapi_to_tool_payload(openapi_spec):
     return tool_payload
 
 
+async def set_tool_servers(request: Request):
+    request.app.state.TOOL_SERVERS = await get_tool_servers_data(
+        request.app.state.config.TOOL_SERVER_CONNECTIONS
+    )
+
+    if request.app.state.redis is not None:
+        await request.app.state.redis.hmset(
+            "tool_servers", request.app.state.TOOL_SERVERS
+        )
+
+    return request.app.state.TOOL_SERVERS
+
+
+async def get_tool_servers(request: Request):
+    tool_servers = []
+    if request.app.state.redis is not None:
+        tool_servers = await request.app.state.redis.hgetall("tool_servers")
+
+    if not tool_servers:
+        await set_tool_servers(request)
+
+    request.app.state.TOOL_SERVERS = tool_servers
+    return request.app.state.TOOL_SERVERS
+
+
 async def get_tool_server_data(token: str, url: str) -> Dict[str, Any]:
     headers = {
         "Accept": "application/json",