|
@@ -212,6 +212,70 @@ origins = ["*"]
|
|
##################################
|
|
##################################
|
|
|
|
|
|
|
|
|
|
|
|
+async def get_body_and_model_and_user(request):
|
|
|
|
+ # Read the original request body
|
|
|
|
+ body = await request.body()
|
|
|
|
+ body_str = body.decode("utf-8")
|
|
|
|
+ body = json.loads(body_str) if body_str else {}
|
|
|
|
+
|
|
|
|
+ model_id = body["model"]
|
|
|
|
+ if model_id not in app.state.MODELS:
|
|
|
|
+ raise "Model not found"
|
|
|
|
+ model = app.state.MODELS[model_id]
|
|
|
|
+
|
|
|
|
+ user = get_current_user(
|
|
|
|
+ request,
|
|
|
|
+ get_http_authorization_cred(request.headers.get("Authorization")),
|
|
|
|
+ )
|
|
|
|
+
|
|
|
|
+ return body, model, user
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+def get_task_model_id(default_model_id):
|
|
|
|
+ # Set the task model
|
|
|
|
+ task_model_id = default_model_id
|
|
|
|
+ # Check if the user has a custom task model and use that model
|
|
|
|
+ if app.state.MODELS[task_model_id]["owned_by"] == "ollama":
|
|
|
|
+ if (
|
|
|
|
+ app.state.config.TASK_MODEL
|
|
|
|
+ and app.state.config.TASK_MODEL in app.state.MODELS
|
|
|
|
+ ):
|
|
|
|
+ task_model_id = app.state.config.TASK_MODEL
|
|
|
|
+ else:
|
|
|
|
+ if (
|
|
|
|
+ app.state.config.TASK_MODEL_EXTERNAL
|
|
|
|
+ and app.state.config.TASK_MODEL_EXTERNAL in app.state.MODELS
|
|
|
|
+ ):
|
|
|
|
+ task_model_id = app.state.config.TASK_MODEL_EXTERNAL
|
|
|
|
+
|
|
|
|
+ return task_model_id
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+def get_filter_function_ids(model):
|
|
|
|
+ def get_priority(function_id):
|
|
|
|
+ function = Functions.get_function_by_id(function_id)
|
|
|
|
+ if function is not None and hasattr(function, "valves"):
|
|
|
|
+ return (function.valves if function.valves else {}).get("priority", 0)
|
|
|
|
+ return 0
|
|
|
|
+
|
|
|
|
+ filter_ids = [function.id for function in Functions.get_global_filter_functions()]
|
|
|
|
+ if "info" in model and "meta" in model["info"]:
|
|
|
|
+ filter_ids.extend(model["info"]["meta"].get("filterIds", []))
|
|
|
|
+ filter_ids = list(set(filter_ids))
|
|
|
|
+
|
|
|
|
+ enabled_filter_ids = [
|
|
|
|
+ function.id
|
|
|
|
+ for function in Functions.get_functions_by_type("filter", active_only=True)
|
|
|
|
+ ]
|
|
|
|
+
|
|
|
|
+ filter_ids = [
|
|
|
|
+ filter_id for filter_id in filter_ids if filter_id in enabled_filter_ids
|
|
|
|
+ ]
|
|
|
|
+
|
|
|
|
+ filter_ids.sort(key=get_priority)
|
|
|
|
+ return filter_ids
|
|
|
|
+
|
|
|
|
+
|
|
async def get_function_call_response(
|
|
async def get_function_call_response(
|
|
messages, files, tool_id, template, task_model_id, user, model
|
|
messages, files, tool_id, template, task_model_id, user, model
|
|
):
|
|
):
|
|
@@ -373,51 +437,6 @@ async def get_function_call_response(
|
|
return None, None, False
|
|
return None, None, False
|
|
|
|
|
|
|
|
|
|
-def get_task_model_id(default_model_id):
|
|
|
|
- # Set the task model
|
|
|
|
- task_model_id = default_model_id
|
|
|
|
- # Check if the user has a custom task model and use that model
|
|
|
|
- if app.state.MODELS[task_model_id]["owned_by"] == "ollama":
|
|
|
|
- if (
|
|
|
|
- app.state.config.TASK_MODEL
|
|
|
|
- and app.state.config.TASK_MODEL in app.state.MODELS
|
|
|
|
- ):
|
|
|
|
- task_model_id = app.state.config.TASK_MODEL
|
|
|
|
- else:
|
|
|
|
- if (
|
|
|
|
- app.state.config.TASK_MODEL_EXTERNAL
|
|
|
|
- and app.state.config.TASK_MODEL_EXTERNAL in app.state.MODELS
|
|
|
|
- ):
|
|
|
|
- task_model_id = app.state.config.TASK_MODEL_EXTERNAL
|
|
|
|
-
|
|
|
|
- return task_model_id
|
|
|
|
-
|
|
|
|
-
|
|
|
|
-def get_filter_function_ids(model):
|
|
|
|
- def get_priority(function_id):
|
|
|
|
- function = Functions.get_function_by_id(function_id)
|
|
|
|
- if function is not None and hasattr(function, "valves"):
|
|
|
|
- return (function.valves if function.valves else {}).get("priority", 0)
|
|
|
|
- return 0
|
|
|
|
-
|
|
|
|
- filter_ids = [function.id for function in Functions.get_global_filter_functions()]
|
|
|
|
- if "info" in model and "meta" in model["info"]:
|
|
|
|
- filter_ids.extend(model["info"]["meta"].get("filterIds", []))
|
|
|
|
- filter_ids = list(set(filter_ids))
|
|
|
|
-
|
|
|
|
- enabled_filter_ids = [
|
|
|
|
- function.id
|
|
|
|
- for function in Functions.get_functions_by_type("filter", active_only=True)
|
|
|
|
- ]
|
|
|
|
-
|
|
|
|
- filter_ids = [
|
|
|
|
- filter_id for filter_id in filter_ids if filter_id in enabled_filter_ids
|
|
|
|
- ]
|
|
|
|
-
|
|
|
|
- filter_ids.sort(key=get_priority)
|
|
|
|
- return filter_ids
|
|
|
|
-
|
|
|
|
-
|
|
|
|
async def chat_completion_functions_handler(body, model, user):
|
|
async def chat_completion_functions_handler(body, model, user):
|
|
skip_files = None
|
|
skip_files = None
|
|
|
|
|
|
@@ -579,25 +598,6 @@ async def chat_completion_files_handler(body):
|
|
}
|
|
}
|
|
|
|
|
|
|
|
|
|
-async def get_body_and_model_and_user(request):
|
|
|
|
- # Read the original request body
|
|
|
|
- body = await request.body()
|
|
|
|
- body_str = body.decode("utf-8")
|
|
|
|
- body = json.loads(body_str) if body_str else {}
|
|
|
|
-
|
|
|
|
- model_id = body["model"]
|
|
|
|
- if model_id not in app.state.MODELS:
|
|
|
|
- raise "Model not found"
|
|
|
|
- model = app.state.MODELS[model_id]
|
|
|
|
-
|
|
|
|
- user = get_current_user(
|
|
|
|
- request,
|
|
|
|
- get_http_authorization_cred(request.headers.get("Authorization")),
|
|
|
|
- )
|
|
|
|
-
|
|
|
|
- return body, model, user
|
|
|
|
-
|
|
|
|
-
|
|
|
|
class ChatCompletionMiddleware(BaseHTTPMiddleware):
|
|
class ChatCompletionMiddleware(BaseHTTPMiddleware):
|
|
async def dispatch(self, request: Request, call_next):
|
|
async def dispatch(self, request: Request, call_next):
|
|
if request.method == "POST" and any(
|
|
if request.method == "POST" and any(
|