Przeglądaj źródła

Prevent duplicate function module loads with caching helper and refactor

Gunwoo Hur 4 miesięcy temu
rodzic
commit
14c3d0c2d1

+ 5 - 4
backend/open_webui/functions.py

@@ -28,7 +28,10 @@ from open_webui.socket.main import (
 from open_webui.models.functions import Functions
 from open_webui.models.models import Models
 
-from open_webui.utils.plugin import load_function_module_by_id
+from open_webui.utils.plugin import (
+    load_function_module_by_id,
+    get_function_module_from_cache,
+)
 from open_webui.utils.tools import get_tools
 from open_webui.utils.access_control import has_access
 
@@ -53,9 +56,7 @@ log.setLevel(SRC_LOG_LEVELS["MAIN"])
 
 
 def get_function_module_by_id(request: Request, pipe_id: str):
-    # Check if function is already loaded
-    function_module, _, _ = load_function_module_by_id(pipe_id)
-    request.app.state.FUNCTIONS[pipe_id] = function_module
+    function_module, _, _ = get_function_module_from_cache(request, pipe_id)
 
     if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
         valves = Functions.get_function_valves_by_id(pipe_id)

+ 9 - 9
backend/open_webui/routers/functions.py

@@ -12,7 +12,11 @@ from open_webui.models.functions import (
     FunctionResponse,
     Functions,
 )
-from open_webui.utils.plugin import load_function_module_by_id, replace_imports
+from open_webui.utils.plugin import (
+    load_function_module_by_id,
+    replace_imports,
+    get_function_module_from_cache,
+)
 from open_webui.config import CACHE_DIR
 from open_webui.constants import ERROR_MESSAGES
 from fastapi import APIRouter, Depends, HTTPException, Request, status
@@ -358,8 +362,7 @@ async def get_function_valves_spec_by_id(
 ):
     function = Functions.get_function_by_id(id)
     if function:
-        function_module, function_type, frontmatter = load_function_module_by_id(id)
-        request.app.state.FUNCTIONS[id] = function_module
+        function_module, function_type, frontmatter = get_function_module_from_cache(request, id)
 
         if hasattr(function_module, "Valves"):
             Valves = function_module.Valves
@@ -383,8 +386,7 @@ async def update_function_valves_by_id(
 ):
     function = Functions.get_function_by_id(id)
     if function:
-        function_module, function_type, frontmatter = load_function_module_by_id(id)
-        request.app.state.FUNCTIONS[id] = function_module
+        function_module, function_type, frontmatter = get_function_module_from_cache(request, id)
 
         if hasattr(function_module, "Valves"):
             Valves = function_module.Valves
@@ -443,8 +445,7 @@ async def get_function_user_valves_spec_by_id(
 ):
     function = Functions.get_function_by_id(id)
     if function:
-        function_module, function_type, frontmatter = load_function_module_by_id(id)
-        request.app.state.FUNCTIONS[id] = function_module
+        function_module, function_type, frontmatter = get_function_module_from_cache(request, id)
 
         if hasattr(function_module, "UserValves"):
             UserValves = function_module.UserValves
@@ -464,8 +465,7 @@ async def update_function_user_valves_by_id(
     function = Functions.get_function_by_id(id)
 
     if function:
-        function_module, function_type, frontmatter = load_function_module_by_id(id)
-        request.app.state.FUNCTIONS[id] = function_module
+        function_module, function_type, frontmatter = get_function_module_from_cache(request, id)
 
         if hasattr(function_module, "UserValves"):
             UserValves = function_module.UserValves

+ 5 - 3
backend/open_webui/utils/chat.py

@@ -40,7 +40,10 @@ from open_webui.models.functions import Functions
 from open_webui.models.models import Models
 
 
-from open_webui.utils.plugin import load_function_module_by_id
+from open_webui.utils.plugin import (
+    load_function_module_by_id,
+    get_function_module_from_cache,
+)
 from open_webui.utils.models import get_all_models, check_model_access
 from open_webui.utils.payload import convert_payload_openai_to_ollama
 from open_webui.utils.response import (
@@ -392,8 +395,7 @@ async def chat_action(request: Request, action_id: str, form_data: dict, user: A
         }
     )
 
-    function_module, _, _ = load_function_module_by_id(action_id)
-    request.app.state.FUNCTIONS[action_id] = function_module
+    function_module, _, _ = get_function_module_from_cache(request, action_id)
 
     if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
         valves = Functions.get_function_valves_by_id(action_id)

+ 5 - 5
backend/open_webui/utils/filter.py

@@ -1,7 +1,10 @@
 import inspect
 import logging
 
-from open_webui.utils.plugin import load_function_module_by_id
+from open_webui.utils.plugin import (
+    load_function_module_by_id,
+    get_function_module_from_cache,
+)
 from open_webui.models.functions import Functions
 from open_webui.env import SRC_LOG_LEVELS
 
@@ -13,10 +16,7 @@ def get_function_module(request, function_id):
     """
     Get the function module by its ID.
     """
-
-    function_module, _, _ = load_function_module_by_id(function_id)
-    request.app.state.FUNCTIONS[function_id] = function_module
-
+    function_module, _, _ = get_function_module_from_cache(request, function_id)
     return function_module
 
 

+ 5 - 3
backend/open_webui/utils/models.py

@@ -13,7 +13,10 @@ from open_webui.models.functions import Functions
 from open_webui.models.models import Models
 
 
-from open_webui.utils.plugin import load_function_module_by_id
+from open_webui.utils.plugin import (
+    load_function_module_by_id,
+    get_function_module_from_cache,
+)
 from open_webui.utils.access_control import has_access
 
 
@@ -239,8 +242,7 @@ async def get_all_models(request, user: UserModel = None):
         ]
 
     def get_function_module_by_id(function_id):
-        function_module, _, _ = load_function_module_by_id(function_id)
-        request.app.state.FUNCTIONS[function_id] = function_module
+        function_module, _, _ = get_function_module_from_cache(request, function_id)
         return function_module
 
     for model in models:

+ 18 - 0
backend/open_webui/utils/plugin.py

@@ -166,6 +166,24 @@ def load_function_module_by_id(function_id, content=None):
         os.unlink(temp_file.name)
 
 
+def get_function_module_from_cache(request, function_id):
+    if (
+        hasattr(request.app.state, "FUNCTIONS")
+        and function_id in request.app.state.FUNCTIONS
+    ):
+        return request.app.state.FUNCTIONS[function_id], None, None
+
+    function_module, function_type, frontmatter = load_function_module_by_id(
+        function_id
+    )
+
+    if not hasattr(request.app.state, "FUNCTIONS"):
+        request.app.state.FUNCTIONS = {}
+
+    request.app.state.FUNCTIONS[function_id] = function_module
+    return function_module, function_type, frontmatter
+
+
 def install_frontmatter_requirements(requirements: str):
     if requirements:
         try: