浏览代码

refac/enh: function valves validation

Timothy Jaeryang Baek 3 周之前
父节点
当前提交
e66e0526ed
共有 2 个文件被更改,包括 71 次插入46 次删除
  1. 61 46
      backend/open_webui/functions.py
  2. 10 0
      backend/open_webui/routers/functions.py

+ 61 - 46
backend/open_webui/functions.py

@@ -19,6 +19,7 @@ from fastapi import (
 from starlette.responses import Response, StreamingResponse
 
 
+from open_webui.constants import ERROR_MESSAGES
 from open_webui.socket.main import (
     get_event_call,
     get_event_emitter,
@@ -60,8 +61,18 @@ def get_function_module_by_id(request: Request, pipe_id: str):
     function_module, _, _ = get_function_module_from_cache(request, pipe_id)
 
     if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
+        Valves = function_module.Valves
         valves = Functions.get_function_valves_by_id(pipe_id)
-        function_module.valves = function_module.Valves(**(valves if valves else {}))
+
+        if valves:
+            try:
+                function_module.valves = Valves(
+                    **({k: v for k, v in valves if v is not None})
+                )
+            except Exception as e:
+                log.exception(f"Error loading valves for function {pipe_id}: {e}")
+                raise e
+
     return function_module
 
 
@@ -70,65 +81,69 @@ async def get_function_models(request):
     pipe_models = []
 
     for pipe in pipes:
-        function_module = get_function_module_by_id(request, pipe.id)
-
-        # Check if function is a manifold
-        if hasattr(function_module, "pipes"):
-            sub_pipes = []
+        try:
+            function_module = get_function_module_by_id(request, pipe.id)
 
-            # Handle pipes being a list, sync function, or async function
-            try:
-                if callable(function_module.pipes):
-                    if asyncio.iscoroutinefunction(function_module.pipes):
-                        sub_pipes = await function_module.pipes()
-                    else:
-                        sub_pipes = function_module.pipes()
-                else:
-                    sub_pipes = function_module.pipes
-            except Exception as e:
-                log.exception(e)
+            # Check if function is a manifold
+            if hasattr(function_module, "pipes"):
                 sub_pipes = []
 
-            log.debug(
-                f"get_function_models: function '{pipe.id}' is a manifold of {sub_pipes}"
-            )
-
-            for p in sub_pipes:
-                sub_pipe_id = f'{pipe.id}.{p["id"]}'
-                sub_pipe_name = p["name"]
+                # Handle pipes being a list, sync function, or async function
+                try:
+                    if callable(function_module.pipes):
+                        if asyncio.iscoroutinefunction(function_module.pipes):
+                            sub_pipes = await function_module.pipes()
+                        else:
+                            sub_pipes = function_module.pipes()
+                    else:
+                        sub_pipes = function_module.pipes
+                except Exception as e:
+                    log.exception(e)
+                    sub_pipes = []
 
-                if hasattr(function_module, "name"):
-                    sub_pipe_name = f"{function_module.name}{sub_pipe_name}"
+                log.debug(
+                    f"get_function_models: function '{pipe.id}' is a manifold of {sub_pipes}"
+                )
 
-                pipe_flag = {"type": pipe.type}
+                for p in sub_pipes:
+                    sub_pipe_id = f'{pipe.id}.{p["id"]}'
+                    sub_pipe_name = p["name"]
+
+                    if hasattr(function_module, "name"):
+                        sub_pipe_name = f"{function_module.name}{sub_pipe_name}"
+
+                    pipe_flag = {"type": pipe.type}
+
+                    pipe_models.append(
+                        {
+                            "id": sub_pipe_id,
+                            "name": sub_pipe_name,
+                            "object": "model",
+                            "created": pipe.created_at,
+                            "owned_by": "openai",
+                            "pipe": pipe_flag,
+                        }
+                    )
+            else:
+                pipe_flag = {"type": "pipe"}
+
+                log.debug(
+                    f"get_function_models: function '{pipe.id}' is a single pipe {{ 'id': {pipe.id}, 'name': {pipe.name} }}"
+                )
 
                 pipe_models.append(
                     {
-                        "id": sub_pipe_id,
-                        "name": sub_pipe_name,
+                        "id": pipe.id,
+                        "name": pipe.name,
                         "object": "model",
                         "created": pipe.created_at,
                         "owned_by": "openai",
                         "pipe": pipe_flag,
                     }
                 )
-        else:
-            pipe_flag = {"type": "pipe"}
-
-            log.debug(
-                f"get_function_models: function '{pipe.id}' is a single pipe {{ 'id': {pipe.id}, 'name': {pipe.name} }}"
-            )
-
-            pipe_models.append(
-                {
-                    "id": pipe.id,
-                    "name": pipe.name,
-                    "object": "model",
-                    "created": pipe.created_at,
-                    "owned_by": "openai",
-                    "pipe": pipe_flag,
-                }
-            )
+        except Exception as e:
+            log.exception(e)
+            continue
 
     return pipe_models
 

+ 10 - 0
backend/open_webui/routers/functions.py

@@ -148,6 +148,16 @@ async def sync_functions(
                 content=function.content,
             )
 
+            if hasattr(function_module, "Valves") and function.valves:
+                Valves = function_module.Valves
+                try:
+                    Valves(**{k: v for k, v in function.valves if v is not None})
+                except Exception as e:
+                    log.exception(
+                        f"Error validating valves for function {function.id}: {e}"
+                    )
+                    raise e
+
         return Functions.sync_functions(user.id, form_data.functions)
     except Exception as e:
         log.exception(f"Failed to load a function: {e}")