1
0
Эх сурвалжийг харах

new embedding.py added for handling openai and ollama embedding

henry 4 сар өмнө
parent
commit
e0769c6a1f

+ 124 - 0
backend/open_webui/utils/embeddings.py

@@ -0,0 +1,124 @@
+import random
+import logging
+import sys
+
+from fastapi import Request
+from open_webui.models.users import UserModel
+from open_webui.models.models import Models
+from open_webui.utils.models import check_model_access
+from open_webui.env import SRC_LOG_LEVELS, GLOBAL_LOG_LEVEL, BYPASS_MODEL_ACCESS_CONTROL
+
+from open_webui.routers.openai import embeddings as openai_embeddings
+from open_webui.routers.ollama import embeddings as ollama_embeddings
+from open_webui.routers.pipelines import process_pipeline_inlet_filter
+
+from open_webui.utils.payload import convert_embedding_payload_openai_to_ollama
+from open_webui.utils.response import convert_response_ollama_to_openai
+
+logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL)
+log = logging.getLogger(__name__)
+log.setLevel(SRC_LOG_LEVELS["MAIN"])
+
+
+async def generate_embeddings(
+    request: Request,
+    form_data: dict,
+    user: UserModel,
+    bypass_filter: bool = False,
+):
+    """
+    Dispatch and handle embeddings generation based on the model type (OpenAI, Ollama, Arena, pipeline, etc).
+
+    Args:
+        request (Request): The FastAPI request context.
+        form_data (dict): The input data sent to the endpoint.
+        user (UserModel): The authenticated user.
+        bypass_filter (bool): If True, disables access filtering (default False).
+
+    Returns:
+        dict: The embeddings response, following OpenAI API compatibility.
+    """
+    if BYPASS_MODEL_ACCESS_CONTROL:
+        bypass_filter = True
+
+    # Attach extra metadata from request.state if present
+    if hasattr(request.state, "metadata"):
+        if "metadata" not in form_data:
+            form_data["metadata"] = request.state.metadata
+        else:
+            form_data["metadata"] = {
+                **form_data["metadata"],
+                **request.state.metadata,
+            }
+
+    # If "direct" flag present, use only that model
+    if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
+        models = {
+            request.state.model["id"]: request.state.model,
+        }
+    else:
+        models = request.app.state.MODELS
+
+    model_id = form_data.get("model")
+    if model_id not in models:
+        raise Exception("Model not found")
+    model = models[model_id]
+
+    # Access filtering
+    if not getattr(request.state, "direct", False):
+        if not bypass_filter and user.role == "user":
+            check_model_access(user, model)
+
+    # Arena "meta-model": select a submodel at random
+    if model.get("owned_by") == "arena":
+        model_ids = model.get("info", {}).get("meta", {}).get("model_ids")
+        filter_mode = model.get("info", {}).get("meta", {}).get("filter_mode")
+        if model_ids and filter_mode == "exclude":
+            model_ids = [
+                m["id"]
+                for m in list(models.values())
+                if m.get("owned_by") != "arena" and m["id"] not in model_ids
+            ]
+        if isinstance(model_ids, list) and model_ids:
+            selected_model_id = random.choice(model_ids)
+        else:
+            model_ids = [
+                m["id"]
+                for m in list(models.values())
+                if m.get("owned_by") != "arena"
+            ]
+            selected_model_id = random.choice(model_ids)
+        inner_form = dict(form_data)
+        inner_form["model"] = selected_model_id
+        response = await generate_embeddings(
+            request, inner_form, user, bypass_filter=True
+        )
+        # Tag which concreted model was chosen
+        if isinstance(response, dict):
+            response = {
+                **response,
+                "selected_model_id": selected_model_id,
+            }
+        return response
+
+    # Pipeline/Function models
+    if model.get("pipe"):
+        # The pipeline handler should provide OpenAI-compatible schema
+        return await process_pipeline_inlet_filter(request, form_data, user, models)
+
+    # Ollama backend
+    if model.get("owned_by") == "ollama":
+        ollama_payload = convert_embedding_payload_openai_to_ollama(form_data)
+        response = await ollama_embeddings(
+            request=request,
+            form_data=ollama_payload,
+            user=user,
+        )
+        return convert_response_ollama_to_openai(response)
+
+    # Default: OpenAI or compatible backend
+    return await openai_embeddings(
+        request=request,
+        form_data=form_data,
+        user=user,
+    )