Browse Source

fix/refac: functions multi-replica issue

Timothy Jaeryang Baek 4 months ago
parent
commit
74ace200fe

+ 2 - 5
backend/open_webui/functions.py

@@ -54,11 +54,8 @@ log.setLevel(SRC_LOG_LEVELS["MAIN"])
 
 def get_function_module_by_id(request: Request, pipe_id: str):
     # Check if function is already loaded
-    if pipe_id not in request.app.state.FUNCTIONS:
-        function_module, _, _ = load_function_module_by_id(pipe_id)
-        request.app.state.FUNCTIONS[pipe_id] = function_module
-    else:
-        function_module = request.app.state.FUNCTIONS[pipe_id]
+    function_module, _, _ = load_function_module_by_id(pipe_id)
+    request.app.state.FUNCTIONS[pipe_id] = function_module
 
     if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
         valves = Functions.get_function_valves_by_id(pipe_id)

+ 8 - 20
backend/open_webui/routers/functions.py

@@ -262,11 +262,8 @@ async def get_function_valves_spec_by_id(
 ):
     function = Functions.get_function_by_id(id)
     if function:
-        if id in request.app.state.FUNCTIONS:
-            function_module = request.app.state.FUNCTIONS[id]
-        else:
-            function_module, function_type, frontmatter = load_function_module_by_id(id)
-            request.app.state.FUNCTIONS[id] = function_module
+        function_module, function_type, frontmatter = load_function_module_by_id(id)
+        request.app.state.FUNCTIONS[id] = function_module
 
         if hasattr(function_module, "Valves"):
             Valves = function_module.Valves
@@ -290,11 +287,8 @@ async def update_function_valves_by_id(
 ):
     function = Functions.get_function_by_id(id)
     if function:
-        if id in request.app.state.FUNCTIONS:
-            function_module = request.app.state.FUNCTIONS[id]
-        else:
-            function_module, function_type, frontmatter = load_function_module_by_id(id)
-            request.app.state.FUNCTIONS[id] = function_module
+        function_module, function_type, frontmatter = load_function_module_by_id(id)
+        request.app.state.FUNCTIONS[id] = function_module
 
         if hasattr(function_module, "Valves"):
             Valves = function_module.Valves
@@ -353,11 +347,8 @@ async def get_function_user_valves_spec_by_id(
 ):
     function = Functions.get_function_by_id(id)
     if function:
-        if id in request.app.state.FUNCTIONS:
-            function_module = request.app.state.FUNCTIONS[id]
-        else:
-            function_module, function_type, frontmatter = load_function_module_by_id(id)
-            request.app.state.FUNCTIONS[id] = function_module
+        function_module, function_type, frontmatter = load_function_module_by_id(id)
+        request.app.state.FUNCTIONS[id] = function_module
 
         if hasattr(function_module, "UserValves"):
             UserValves = function_module.UserValves
@@ -377,11 +368,8 @@ async def update_function_user_valves_by_id(
     function = Functions.get_function_by_id(id)
 
     if function:
-        if id in request.app.state.FUNCTIONS:
-            function_module = request.app.state.FUNCTIONS[id]
-        else:
-            function_module, function_type, frontmatter = load_function_module_by_id(id)
-            request.app.state.FUNCTIONS[id] = function_module
+        function_module, function_type, frontmatter = load_function_module_by_id(id)
+        request.app.state.FUNCTIONS[id] = function_module
 
         if hasattr(function_module, "UserValves"):
             UserValves = function_module.UserValves

+ 2 - 5
backend/open_webui/utils/chat.py

@@ -392,11 +392,8 @@ async def chat_action(request: Request, action_id: str, form_data: dict, user: A
         }
     )
 
-    if action_id in request.app.state.FUNCTIONS:
-        function_module = request.app.state.FUNCTIONS[action_id]
-    else:
-        function_module, _, _ = load_function_module_by_id(action_id)
-        request.app.state.FUNCTIONS[action_id] = function_module
+    function_module, _, _ = load_function_module_by_id(action_id)
+    request.app.state.FUNCTIONS[action_id] = function_module
 
     if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
         valves = Functions.get_function_valves_by_id(action_id)

+ 3 - 5
backend/open_webui/utils/filter.py

@@ -13,11 +13,9 @@ def get_function_module(request, function_id):
     """
     Get the function module by its ID.
     """
-    if function_id in request.app.state.FUNCTIONS:
-        function_module = request.app.state.FUNCTIONS[function_id]
-    else:
-        function_module, _, _ = load_function_module_by_id(function_id)
-        request.app.state.FUNCTIONS[function_id] = function_module
+
+    function_module, _, _ = load_function_module_by_id(function_id)
+    request.app.state.FUNCTIONS[function_id] = function_module
 
     return function_module
 

+ 2 - 5
backend/open_webui/utils/models.py

@@ -239,11 +239,8 @@ async def get_all_models(request, user: UserModel = None):
         ]
 
     def get_function_module_by_id(function_id):
-        if function_id in request.app.state.FUNCTIONS:
-            function_module = request.app.state.FUNCTIONS[function_id]
-        else:
-            function_module, _, _ = load_function_module_by_id(function_id)
-            request.app.state.FUNCTIONS[function_id] = function_module
+        function_module, _, _ = load_function_module_by_id(function_id)
+        request.app.state.FUNCTIONS[function_id] = function_module
         return function_module
 
     for model in models: