Browse Source

Merge pull request #14667 from hdnh2006/main

feat: OpenAI-compatible `/api/embeddings` endpoint with provider-agnostic OpenWebUI architecture
Tim Jaeryang Baek 3 weeks ago
parent
commit
14e158fde9

+ 32 - 0
backend/open_webui/main.py

@@ -413,6 +413,7 @@ from open_webui.utils.chat import (
     chat_completed as chat_completed_handler,
     chat_action as chat_action_handler,
 )
+from open_webui.utils.embeddings import generate_embeddings
 from open_webui.utils.middleware import process_chat_payload, process_chat_response
 from open_webui.utils.access_control import has_access
 
@@ -1545,6 +1546,37 @@ async def get_app_latest_release_version(user=Depends(get_verified_user)):
 async def get_app_changelog():
     return {key: CHANGELOG[key] for idx, key in enumerate(CHANGELOG) if idx < 5}
 
+##################################
+# Embeddings
+##################################
+
+@app.post("/api/embeddings")
+async def embeddings_endpoint(
+    request: Request,
+    form_data: dict,
+    user=Depends(get_verified_user)
+):
+    """
+    OpenAI-compatible embeddings endpoint.
+
+    This handler:
+      - Performs user/model checks and dispatches to the correct backend.
+      - Supports OpenAI, Ollama, arena models, pipelines, and any compatible provider.
+
+    Args:
+        request (Request): Request context.
+        form_data (dict): OpenAI-like payload (e.g., {"model": "...", "input": [...]})
+        user (UserModel): Authenticated user.
+
+    Returns:
+        dict: OpenAI-compatible embeddings response.
+    """
+    # Make sure models are loaded in app state
+    if not request.app.state.MODELS:
+        await get_all_models(request, user=user)
+    # Use generic dispatcher in utils.embeddings
+    return await generate_embeddings(request, form_data, user)
+
 
 ############################
 # OAuth Login & Callback

+ 79 - 0
backend/open_webui/routers/openai.py

@@ -886,6 +886,85 @@ async def generate_chat_completion(
                 r.close()
             await session.close()
 
+async def embeddings(request: Request, form_data: dict, user):
+    """
+    Calls the embeddings endpoint for OpenAI-compatible providers.
+    
+    Args:
+        request (Request): The FastAPI request context.
+        form_data (dict): OpenAI-compatible embeddings payload.
+        user (UserModel): The authenticated user.
+    
+    Returns:
+        dict: OpenAI-compatible embeddings response.
+    """
+    idx = 0
+    # Prepare payload/body
+    body = json.dumps(form_data)
+    # Find correct backend url/key based on model
+    await get_all_models(request, user=user)
+    model_id = form_data.get("model")
+    models = request.app.state.OPENAI_MODELS
+    if model_id in models:
+        idx = models[model_id]["urlIdx"]
+    url = request.app.state.config.OPENAI_API_BASE_URLS[idx]
+    key = request.app.state.config.OPENAI_API_KEYS[idx]
+    r = None
+    session = None
+    streaming = False
+    try:
+        session = aiohttp.ClientSession(trust_env=True)
+        r = await session.request(
+            method="POST",
+            url=f"{url}/embeddings",
+            data=body,
+            headers={
+                "Authorization": f"Bearer {key}",
+                "Content-Type": "application/json",
+                **(
+                    {
+                        "X-OpenWebUI-User-Name": user.name,
+                        "X-OpenWebUI-User-Id": user.id,
+                        "X-OpenWebUI-User-Email": user.email,
+                        "X-OpenWebUI-User-Role": user.role,
+                    }
+                    if ENABLE_FORWARD_USER_INFO_HEADERS and user else {}
+                ),
+            },
+        )
+        r.raise_for_status()
+        if "text/event-stream" in r.headers.get("Content-Type", ""):
+            streaming = True
+            return StreamingResponse(
+                r.content,
+                status_code=r.status,
+                headers=dict(r.headers),
+                background=BackgroundTask(
+                    cleanup_response, response=r, session=session
+                ),
+            )
+        else:
+            response_data = await r.json()
+            return response_data
+    except Exception as e:
+        log.exception(e)
+        detail = None
+        if r is not None:
+            try:
+                res = await r.json()
+                if "error" in res:
+                    detail = f"External: {res['error']['message'] if 'message' in res['error'] else res['error']}"
+            except Exception:
+                detail = f"External: {e}"
+        raise HTTPException(
+            status_code=r.status if r else 500,
+            detail=detail if detail else "Open WebUI: Server Connection Error",
+        )
+    finally:
+        if not streaming and session:
+            if r:
+                r.close()
+            await session.close()
 
 @router.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])
 async def proxy(path: str, request: Request, user=Depends(get_verified_user)):

+ 127 - 0
backend/open_webui/utils/embeddings.py

@@ -0,0 +1,127 @@
+import random
+import logging
+import sys
+
+from fastapi import Request
+from open_webui.models.users import UserModel
+from open_webui.models.models import Models
+from open_webui.utils.models import check_model_access
+from open_webui.env import SRC_LOG_LEVELS, GLOBAL_LOG_LEVEL, BYPASS_MODEL_ACCESS_CONTROL
+
+from open_webui.routers.openai import embeddings as openai_embeddings
+from open_webui.routers.ollama import embeddings as ollama_embeddings
+from open_webui.routers.ollama import GenerateEmbeddingsForm
+from open_webui.routers.pipelines import process_pipeline_inlet_filter
+
+
+from open_webui.utils.payload import convert_embedding_payload_openai_to_ollama
+from open_webui.utils.response import convert_embedding_response_ollama_to_openai
+
+logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL)
+log = logging.getLogger(__name__)
+log.setLevel(SRC_LOG_LEVELS["MAIN"])
+
+
+async def generate_embeddings(
+    request: Request,
+    form_data: dict,
+    user: UserModel,
+    bypass_filter: bool = False,
+):
+    """
+    Dispatch and handle embeddings generation based on the model type (OpenAI, Ollama, Arena, pipeline, etc).
+
+    Args:
+        request (Request): The FastAPI request context.
+        form_data (dict): The input data sent to the endpoint.
+        user (UserModel): The authenticated user.
+        bypass_filter (bool): If True, disables access filtering (default False).
+
+    Returns:
+        dict: The embeddings response, following OpenAI API compatibility.
+    """
+    if BYPASS_MODEL_ACCESS_CONTROL:
+        bypass_filter = True
+
+    # Attach extra metadata from request.state if present
+    if hasattr(request.state, "metadata"):
+        if "metadata" not in form_data:
+            form_data["metadata"] = request.state.metadata
+        else:
+            form_data["metadata"] = {
+                **form_data["metadata"],
+                **request.state.metadata,
+            }
+
+    # If "direct" flag present, use only that model
+    if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
+        models = {
+            request.state.model["id"]: request.state.model,
+        }
+    else:
+        models = request.app.state.MODELS
+
+    model_id = form_data.get("model")
+    if model_id not in models:
+        raise Exception("Model not found")
+    model = models[model_id]
+
+    # Access filtering
+    if not getattr(request.state, "direct", False):
+        if not bypass_filter and user.role == "user":
+            check_model_access(user, model)
+
+    # Arena "meta-model": select a submodel at random
+    if model.get("owned_by") == "arena":
+        model_ids = model.get("info", {}).get("meta", {}).get("model_ids")
+        filter_mode = model.get("info", {}).get("meta", {}).get("filter_mode")
+        if model_ids and filter_mode == "exclude":
+            model_ids = [
+                m["id"]
+                for m in list(models.values())
+                if m.get("owned_by") != "arena" and m["id"] not in model_ids
+            ]
+        if isinstance(model_ids, list) and model_ids:
+            selected_model_id = random.choice(model_ids)
+        else:
+            model_ids = [
+                m["id"]
+                for m in list(models.values())
+                if m.get("owned_by") != "arena"
+            ]
+            selected_model_id = random.choice(model_ids)
+        inner_form = dict(form_data)
+        inner_form["model"] = selected_model_id
+        response = await generate_embeddings(
+            request, inner_form, user, bypass_filter=True
+        )
+        # Tag which concreted model was chosen
+        if isinstance(response, dict):
+            response = {
+                **response,
+                "selected_model_id": selected_model_id,
+            }
+        return response
+
+    # Pipeline/Function models
+    if model.get("pipe"):
+        # The pipeline handler should provide OpenAI-compatible schema
+        return await process_pipeline_inlet_filter(request, form_data, user, models)
+
+    # Ollama backend
+    if model.get("owned_by") == "ollama":
+        ollama_payload = convert_embedding_payload_openai_to_ollama(form_data)
+        form_obj = GenerateEmbeddingsForm(**ollama_payload)
+        response = await ollama_embeddings(
+            request=request,
+            form_data=form_obj,
+            user=user,
+        )
+        return convert_embedding_response_ollama_to_openai(response)
+
+    # Default: OpenAI or compatible backend
+    return await openai_embeddings(
+        request=request,
+        form_data=form_data,
+        user=user,
+    )

+ 31 - 0
backend/open_webui/utils/payload.py

@@ -329,3 +329,34 @@ def convert_payload_openai_to_ollama(openai_payload: dict) -> dict:
             ollama_payload["format"] = format
 
     return ollama_payload
+
+
+def convert_embedding_payload_openai_to_ollama(openai_payload: dict) -> dict:
+    """
+    Convert an embeddings request payload from OpenAI format to Ollama format.
+
+    Args:
+        openai_payload (dict): The original payload designed for OpenAI API usage.
+
+    Returns:
+        dict: A payload compatible with the Ollama API embeddings endpoint.
+    """
+    ollama_payload = {
+        "model": openai_payload.get("model")
+    }
+    input_value = openai_payload.get("input")
+
+    # Ollama expects 'input' as a list, and 'prompt' as a single string.
+    if isinstance(input_value, list):
+        ollama_payload["input"] = input_value
+        ollama_payload["prompt"] = "\n".join(str(x) for x in input_value)
+    else:
+        ollama_payload["input"] = [input_value]
+        ollama_payload["prompt"] = str(input_value)
+
+    # Optionally forward other fields if present
+    for optional_key in ("options", "truncate", "keep_alive"):
+        if optional_key in openai_payload:
+            ollama_payload[optional_key] = openai_payload[optional_key]
+
+    return ollama_payload

+ 52 - 0
backend/open_webui/utils/response.py

@@ -125,3 +125,55 @@ async def convert_streaming_response_ollama_to_openai(ollama_streaming_response)
         yield line
 
     yield "data: [DONE]\n\n"
+
+def convert_embedding_response_ollama_to_openai(response) -> dict:
+    """
+    Convert the response from Ollama embeddings endpoint to the OpenAI-compatible format.
+
+    Args:
+        response (dict): The response from the Ollama API, 
+            e.g. {"embedding": [...], "model": "..."}
+            or {"embeddings": [{"embedding": [...], "index": 0}, ...], "model": "..."}
+
+    Returns:
+        dict: Response adapted to OpenAI's embeddings API format.
+            e.g. {
+                "object": "list",
+                "data": [
+                    {"object": "embedding", "embedding": [...], "index": 0},
+                    ...
+                ],
+                "model": "...",
+            }
+    """
+    # Ollama batch-style output
+    if isinstance(response, dict) and "embeddings" in response:
+        openai_data = []
+        for i, emb in enumerate(response["embeddings"]):
+            openai_data.append({
+                "object": "embedding",
+                "embedding": emb.get("embedding"),
+                "index": emb.get("index", i),
+            })
+        return {
+            "object": "list",
+            "data": openai_data,
+            "model": response.get("model"),
+        }
+    # Ollama single output
+    elif isinstance(response, dict) and "embedding" in response:
+        return {
+            "object": "list",
+            "data": [{
+                "object": "embedding",
+                "embedding": response["embedding"],
+                "index": 0,
+            }],
+            "model": response.get("model"),
+        }
+    # Already OpenAI-compatible?
+    elif isinstance(response, dict) and "data" in response and isinstance(response["data"], list):
+        return response
+
+    # Fallback: return as is if unrecognized
+    return response