|
@@ -0,0 +1,127 @@
|
|
|
+import random
|
|
|
+import logging
|
|
|
+import sys
|
|
|
+
|
|
|
+from fastapi import Request
|
|
|
+from open_webui.models.users import UserModel
|
|
|
+from open_webui.models.models import Models
|
|
|
+from open_webui.utils.models import check_model_access
|
|
|
+from open_webui.env import SRC_LOG_LEVELS, GLOBAL_LOG_LEVEL, BYPASS_MODEL_ACCESS_CONTROL
|
|
|
+
|
|
|
+from open_webui.routers.openai import embeddings as openai_embeddings
|
|
|
+from open_webui.routers.ollama import embeddings as ollama_embeddings
|
|
|
+from open_webui.routers.ollama import GenerateEmbeddingsForm
|
|
|
+from open_webui.routers.pipelines import process_pipeline_inlet_filter
|
|
|
+
|
|
|
+
|
|
|
+from open_webui.utils.payload import convert_embedding_payload_openai_to_ollama
|
|
|
+from open_webui.utils.response import convert_embedding_response_ollama_to_openai
|
|
|
+
|
|
|
# Module-level logging setup: root logging is configured to write to stdout
# at GLOBAL_LOG_LEVEL.
+logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL)
|
|
|
# Logger for this module, named after the module itself.
+log = logging.getLogger(__name__)
|
|
|
# This module's logger uses the "MAIN" entry of SRC_LOG_LEVELS as its level.
+log.setLevel(SRC_LOG_LEVELS["MAIN"])
|
|
|
+
|
|
|
+
|
|
|
def _select_arena_model_id(model: dict, models: dict):
    """Pick one concrete model id at random for an arena meta-model.

    Args:
        model (dict): The arena model entry; its ``info.meta`` may carry
            ``model_ids`` and ``filter_mode``.
        models (dict): Mapping of model id to model entry to choose from.

    Returns:
        The id of the randomly selected model.

    Raises:
        Exception: If there is no candidate model to choose from.
    """
    # Tolerate "info"/"meta" being absent *or* explicitly None.
    meta = (model.get("info") or {}).get("meta") or {}
    candidate_ids = meta.get("model_ids")

    # In "exclude" mode the configured ids are a deny-list: choose among every
    # non-arena model that is not listed.
    if candidate_ids and meta.get("filter_mode") == "exclude":
        candidate_ids = [
            m["id"]
            for m in list(models.values())
            if m.get("owned_by") != "arena" and m["id"] not in candidate_ids
        ]

    # No usable explicit list: fall back to every non-arena model.
    if not (isinstance(candidate_ids, list) and candidate_ids):
        candidate_ids = [
            m["id"] for m in list(models.values()) if m.get("owned_by") != "arena"
        ]

    # random.choice raises a bare IndexError on an empty sequence; fail with
    # an explicit message instead.
    if not candidate_ids:
        raise Exception("No models available for arena selection")

    return random.choice(candidate_ids)


async def generate_embeddings(
    request: Request,
    form_data: dict,
    user: UserModel,
    bypass_filter: bool = False,
):
    """
    Dispatch and handle embeddings generation based on the model type (OpenAI, Ollama, Arena, pipeline, etc).

    Args:
        request (Request): The FastAPI request context.
        form_data (dict): The input data sent to the endpoint. Its "metadata"
            entry is updated in place when ``request.state.metadata`` exists.
        user (UserModel): The authenticated user.
        bypass_filter (bool): If True, disables access filtering (default False).

    Returns:
        dict: The embeddings response, following OpenAI API compatibility.
        For arena models a dict response additionally carries
        "selected_model_id".

    Raises:
        Exception: If the requested model is not among the available models,
            or an arena model has no candidate model to delegate to.
    """
    if BYPASS_MODEL_ACCESS_CONTROL:
        bypass_filter = True

    # Attach extra metadata from request.state if present; on key collisions
    # the request.state values win.
    if hasattr(request.state, "metadata"):
        if "metadata" not in form_data:
            form_data["metadata"] = request.state.metadata
        else:
            form_data["metadata"] = {
                **form_data["metadata"],
                **request.state.metadata,
            }

    # If "direct" flag present, use only that model
    is_direct = getattr(request.state, "direct", False)
    if is_direct and hasattr(request.state, "model"):
        models = {
            request.state.model["id"]: request.state.model,
        }
    else:
        models = request.app.state.MODELS

    model_id = form_data.get("model")
    if model_id not in models:
        raise Exception("Model not found")
    model = models[model_id]

    # Access filtering: only non-direct requests from plain "user" roles are
    # checked, and only when filtering has not been bypassed.
    if not is_direct and not bypass_filter and user.role == "user":
        check_model_access(user, model)

    # Arena "meta-model": select a submodel at random
    if model.get("owned_by") == "arena":
        selected_model_id = _select_arena_model_id(model, models)

        # Work on a shallow copy so the caller's "model" entry is not rewritten.
        inner_form = dict(form_data)
        inner_form["model"] = selected_model_id
        response = await generate_embeddings(
            request, inner_form, user, bypass_filter=True
        )

        # Tag which concrete model was chosen
        if isinstance(response, dict):
            response = {
                **response,
                "selected_model_id": selected_model_id,
            }
        return response

    # Pipeline/Function models
    if model.get("pipe"):
        # NOTE(review): the value returned here is whatever
        # process_pipeline_inlet_filter produces. Its name suggests it filters
        # the inbound payload rather than computing embeddings; confirm that
        # it really yields an OpenAI-compatible embeddings response.
        return await process_pipeline_inlet_filter(request, form_data, user, models)

    # Ollama backend: translate the payload in, and the response back out.
    if model.get("owned_by") == "ollama":
        ollama_payload = convert_embedding_payload_openai_to_ollama(form_data)
        response = await ollama_embeddings(
            request=request,
            form_data=GenerateEmbeddingsForm(**ollama_payload),
            user=user,
        )
        return convert_embedding_response_ollama_to_openai(response)

    # Default: OpenAI or compatible backend
    return await openai_embeddings(
        request=request,
        form_data=form_data,
        user=user,
    )
|