Browse Source

feat: toggle filter middleware

Timothy Jaeryang Baek 4 months ago
parent
commit
1f38350128

+ 1 - 0
backend/open_webui/main.py

@@ -1186,6 +1186,7 @@ async def chat_completion(
             "chat_id": form_data.pop("chat_id", None),
             "message_id": form_data.pop("id", None),
             "session_id": form_data.pop("session_id", None),
+            "filter_ids": form_data.pop("filter_ids", None),
             "tool_ids": form_data.get("tool_ids", None),
             "tool_servers": form_data.pop("tool_servers", None),
             "files": form_data.get("files", None),

+ 1 - 4
backend/open_webui/routers/tasks.py

@@ -20,10 +20,7 @@ from open_webui.utils.auth import get_admin_user, get_verified_user
 from open_webui.constants import TASKS
 
 from open_webui.routers.pipelines import process_pipeline_inlet_filter
-from open_webui.utils.filter import (
-    get_sorted_filter_ids,
-    process_filter_functions,
-)
+
 from open_webui.utils.task import get_task_model_id
 
 from open_webui.config import (

+ 1 - 1
backend/open_webui/utils/chat.py

@@ -330,7 +330,7 @@ async def chat_completed(request: Request, form_data: dict, user: Any):
     try:
         filter_functions = [
             Functions.get_function_by_id(filter_id)
-            for filter_id in get_sorted_filter_ids(model)
+            for filter_id in get_sorted_filter_ids(request, model)
         ]
 
         result, _ = await process_filter_functions(

+ 27 - 10
backend/open_webui/utils/filter.py

@@ -9,7 +9,20 @@ log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["MAIN"])
 
 
-def get_sorted_filter_ids(model: dict):
+def get_function_module(request, function_id):
+    """
+    Get the function module by its ID.
+    """
+    if function_id in request.app.state.FUNCTIONS:
+        function_module = request.app.state.FUNCTIONS[function_id]
+    else:
+        function_module, _, _ = load_function_module_by_id(function_id)
+        request.app.state.FUNCTIONS[function_id] = function_module
+
+    return function_module
+
+
+def get_sorted_filter_ids(request, model: dict, enabled_filter_ids: list = None):
     def get_priority(function_id):
         function = Functions.get_function_by_id(function_id)
         if function is not None:
@@ -21,14 +34,23 @@ def get_sorted_filter_ids(model: dict):
     if "info" in model and "meta" in model["info"]:
         filter_ids.extend(model["info"]["meta"].get("filterIds", []))
         filter_ids = list(set(filter_ids))
-
-    enabled_filter_ids = [
+    active_filter_ids = [
         function.id
         for function in Functions.get_functions_by_type("filter", active_only=True)
     ]
 
-    filter_ids = [fid for fid in filter_ids if fid in enabled_filter_ids]
+    for filter_id in active_filter_ids:
+        function_module = get_function_module(request, filter_id)
+
+        if getattr(function_module, "toggle", None) and (
+            filter_id not in enabled_filter_ids
+        ):
+            active_filter_ids.remove(filter_id)
+            continue
+
+    filter_ids = [fid for fid in filter_ids if fid in active_filter_ids]
     filter_ids.sort(key=get_priority)
+
     return filter_ids
 
 
@@ -43,12 +65,7 @@ async def process_filter_functions(
         if not filter:
             continue
 
-        if filter_id in request.app.state.FUNCTIONS:
-            function_module = request.app.state.FUNCTIONS[filter_id]
-        else:
-            function_module, _, _ = load_function_module_by_id(filter_id)
-            request.app.state.FUNCTIONS[filter_id] = function_module
-
+        function_module = get_function_module(request, filter_id)
         # Prepare handler function
         handler = getattr(function_module, filter_type, None)
         if not handler:

+ 7 - 2
backend/open_webui/utils/middleware.py

@@ -754,9 +754,12 @@ async def process_chat_payload(request, form_data, user, metadata, model):
         raise e
 
     try:
+
         filter_functions = [
             Functions.get_function_by_id(filter_id)
-            for filter_id in get_sorted_filter_ids(model)
+            for filter_id in get_sorted_filter_ids(
+                request, model, metadata.get("filter_ids", [])
+            )
         ]
 
         form_data, flags = await process_filter_functions(
@@ -1188,7 +1191,9 @@ async def process_chat_response(
     }
     filter_functions = [
         Functions.get_function_by_id(filter_id)
-        for filter_id in get_sorted_filter_ids(model)
+        for filter_id in get_sorted_filter_ids(
+            request, model, metadata.get("filter_ids", [])
+        )
     ]
 
     # Streaming response

+ 1 - 0
src/lib/components/chat/Chat.svelte

@@ -1635,6 +1635,7 @@
 				},
 
 				files: (files?.length ?? 0) > 0 ? files : undefined,
+
 				filter_ids: selectedFilterIds.length > 0 ? selectedFilterIds : undefined,
 				tool_ids: selectedToolIds.length > 0 ? selectedToolIds : undefined,
 				tool_servers: $toolServers,