Timothy J. Baek 10 月之前
父節點
當前提交
e5895af7a0
共有 1 個文件被更改,包括 64 次插入64 次删除
  1. 64 64
      backend/main.py

+ 64 - 64
backend/main.py

@@ -212,6 +212,70 @@ origins = ["*"]
 ##################################
 
 
+async def get_body_and_model_and_user(request):
+    # Read the original request body
+    body = await request.body()
+    body_str = body.decode("utf-8")
+    body = json.loads(body_str) if body_str else {}
+
+    model_id = body["model"]
+    if model_id not in app.state.MODELS:
+        raise "Model not found"
+    model = app.state.MODELS[model_id]
+
+    user = get_current_user(
+        request,
+        get_http_authorization_cred(request.headers.get("Authorization")),
+    )
+
+    return body, model, user
+
+
+def get_task_model_id(default_model_id):
+    # Set the task model
+    task_model_id = default_model_id
+    # Check if the user has a custom task model and use that model
+    if app.state.MODELS[task_model_id]["owned_by"] == "ollama":
+        if (
+            app.state.config.TASK_MODEL
+            and app.state.config.TASK_MODEL in app.state.MODELS
+        ):
+            task_model_id = app.state.config.TASK_MODEL
+    else:
+        if (
+            app.state.config.TASK_MODEL_EXTERNAL
+            and app.state.config.TASK_MODEL_EXTERNAL in app.state.MODELS
+        ):
+            task_model_id = app.state.config.TASK_MODEL_EXTERNAL
+
+    return task_model_id
+
+
+def get_filter_function_ids(model):
+    def get_priority(function_id):
+        function = Functions.get_function_by_id(function_id)
+        if function is not None and hasattr(function, "valves"):
+            return (function.valves if function.valves else {}).get("priority", 0)
+        return 0
+
+    filter_ids = [function.id for function in Functions.get_global_filter_functions()]
+    if "info" in model and "meta" in model["info"]:
+        filter_ids.extend(model["info"]["meta"].get("filterIds", []))
+        filter_ids = list(set(filter_ids))
+
+    enabled_filter_ids = [
+        function.id
+        for function in Functions.get_functions_by_type("filter", active_only=True)
+    ]
+
+    filter_ids = [
+        filter_id for filter_id in filter_ids if filter_id in enabled_filter_ids
+    ]
+
+    filter_ids.sort(key=get_priority)
+    return filter_ids
+
+
 async def get_function_call_response(
     messages, files, tool_id, template, task_model_id, user, model
 ):
@@ -373,51 +437,6 @@ async def get_function_call_response(
     return None, None, False
 
 
-def get_task_model_id(default_model_id):
-    # Set the task model
-    task_model_id = default_model_id
-    # Check if the user has a custom task model and use that model
-    if app.state.MODELS[task_model_id]["owned_by"] == "ollama":
-        if (
-            app.state.config.TASK_MODEL
-            and app.state.config.TASK_MODEL in app.state.MODELS
-        ):
-            task_model_id = app.state.config.TASK_MODEL
-    else:
-        if (
-            app.state.config.TASK_MODEL_EXTERNAL
-            and app.state.config.TASK_MODEL_EXTERNAL in app.state.MODELS
-        ):
-            task_model_id = app.state.config.TASK_MODEL_EXTERNAL
-
-    return task_model_id
-
-
-def get_filter_function_ids(model):
-    def get_priority(function_id):
-        function = Functions.get_function_by_id(function_id)
-        if function is not None and hasattr(function, "valves"):
-            return (function.valves if function.valves else {}).get("priority", 0)
-        return 0
-
-    filter_ids = [function.id for function in Functions.get_global_filter_functions()]
-    if "info" in model and "meta" in model["info"]:
-        filter_ids.extend(model["info"]["meta"].get("filterIds", []))
-        filter_ids = list(set(filter_ids))
-
-    enabled_filter_ids = [
-        function.id
-        for function in Functions.get_functions_by_type("filter", active_only=True)
-    ]
-
-    filter_ids = [
-        filter_id for filter_id in filter_ids if filter_id in enabled_filter_ids
-    ]
-
-    filter_ids.sort(key=get_priority)
-    return filter_ids
-
-
 async def chat_completion_functions_handler(body, model, user):
     skip_files = None
 
@@ -579,25 +598,6 @@ async def chat_completion_files_handler(body):
     }
 
 
-async def get_body_and_model_and_user(request):
-    # Read the original request body
-    body = await request.body()
-    body_str = body.decode("utf-8")
-    body = json.loads(body_str) if body_str else {}
-
-    model_id = body["model"]
-    if model_id not in app.state.MODELS:
-        raise "Model not found"
-    model = app.state.MODELS[model_id]
-
-    user = get_current_user(
-        request,
-        get_http_authorization_cred(request.headers.get("Authorization")),
-    )
-
-    return body, model, user
-
-
 class ChatCompletionMiddleware(BaseHTTPMiddleware):
     async def dispatch(self, request: Request, call_next):
         if request.method == "POST" and any(