Jelajahi Sumber

refac: embeddings endpoint

Timothy Jaeryang Baek 4 bulan lalu
induk
melakukan
ab36b8aeae
2 mengubah file dengan 38 tambahan dan 75 penghapusan
  1. 31 31
      backend/open_webui/main.py
  2. 7 44
      backend/open_webui/utils/embeddings.py

+ 31 - 31
backend/open_webui/main.py

@@ -1208,6 +1208,37 @@ async def get_base_models(request: Request, user=Depends(get_admin_user)):
     return {"data": models}
     return {"data": models}
 
 
 
 
+##################################
+# Embeddings
+##################################
+
+
+@app.post("/api/embeddings")
+async def embeddings(
+    request: Request, form_data: dict, user=Depends(get_verified_user)
+):
+    """
+    OpenAI-compatible embeddings endpoint.
+
+    This handler:
+      - Performs user/model checks and dispatches to the correct backend.
+      - Supports Ollama and OpenAI-compatible backends (arena and pipeline models are not dispatched here).
+
+    Args:
+        request (Request): Request context.
+        form_data (dict): OpenAI-like payload (e.g., {"model": "...", "input": [...]})
+        user (UserModel): Authenticated user.
+
+    Returns:
+        dict: OpenAI-compatible embeddings response.
+    """
+    # Make sure models are loaded in app state
+    if not request.app.state.MODELS:
+        await get_all_models(request, user=user)
+    # Use generic dispatcher in utils.embeddings
+    return await generate_embeddings(request, form_data, user)
+
+
 @app.post("/api/chat/completions")
 @app.post("/api/chat/completions")
 async def chat_completion(
 async def chat_completion(
     request: Request,
     request: Request,
@@ -1550,37 +1581,6 @@ async def get_app_latest_release_version(user=Depends(get_verified_user)):
 async def get_app_changelog():
 async def get_app_changelog():
     return {key: CHANGELOG[key] for idx, key in enumerate(CHANGELOG) if idx < 5}
     return {key: CHANGELOG[key] for idx, key in enumerate(CHANGELOG) if idx < 5}
 
 
-##################################
-# Embeddings
-##################################
-
-@app.post("/api/embeddings")
-async def embeddings_endpoint(
-    request: Request,
-    form_data: dict,
-    user=Depends(get_verified_user)
-):
-    """
-    OpenAI-compatible embeddings endpoint.
-
-    This handler:
-      - Performs user/model checks and dispatches to the correct backend.
-      - Supports OpenAI, Ollama, arena models, pipelines, and any compatible provider.
-
-    Args:
-        request (Request): Request context.
-        form_data (dict): OpenAI-like payload (e.g., {"model": "...", "input": [...]})
-        user (UserModel): Authenticated user.
-
-    Returns:
-        dict: OpenAI-compatible embeddings response.
-    """
-    # Make sure models are loaded in app state
-    if not request.app.state.MODELS:
-        await get_all_models(request, user=user)
-    # Use generic dispatcher in utils.embeddings
-    return await generate_embeddings(request, form_data, user)
-
 
 
 ############################
 ############################
 # OAuth Login & Callback
 # OAuth Login & Callback

+ 7 - 44
backend/open_webui/utils/embeddings.py

@@ -9,9 +9,10 @@ from open_webui.utils.models import check_model_access
 from open_webui.env import SRC_LOG_LEVELS, GLOBAL_LOG_LEVEL, BYPASS_MODEL_ACCESS_CONTROL
 from open_webui.env import SRC_LOG_LEVELS, GLOBAL_LOG_LEVEL, BYPASS_MODEL_ACCESS_CONTROL
 
 
 from open_webui.routers.openai import embeddings as openai_embeddings
 from open_webui.routers.openai import embeddings as openai_embeddings
-from open_webui.routers.ollama import embeddings as ollama_embeddings
-from open_webui.routers.ollama import GenerateEmbeddingsForm
-from open_webui.routers.pipelines import process_pipeline_inlet_filter
+from open_webui.routers.ollama import (
+    embeddings as ollama_embeddings,
+    GenerateEmbeddingsForm,
+)
 
 
 
 
 from open_webui.utils.payload import convert_embedding_payload_openai_to_ollama
 from open_webui.utils.payload import convert_embedding_payload_openai_to_ollama
@@ -29,7 +30,7 @@ async def generate_embeddings(
     bypass_filter: bool = False,
     bypass_filter: bool = False,
 ):
 ):
     """
     """
-    Dispatch and handle embeddings generation based on the model type (OpenAI, Ollama, Arena, pipeline, etc).
+    Dispatch and handle embeddings generation based on the model type (OpenAI, Ollama).
 
 
     Args:
     Args:
         request (Request): The FastAPI request context.
         request (Request): The FastAPI request context.
@@ -71,50 +72,12 @@ async def generate_embeddings(
         if not bypass_filter and user.role == "user":
         if not bypass_filter and user.role == "user":
             check_model_access(user, model)
             check_model_access(user, model)
 
 
-    # Arena "meta-model": select a submodel at random
-    if model.get("owned_by") == "arena":
-        model_ids = model.get("info", {}).get("meta", {}).get("model_ids")
-        filter_mode = model.get("info", {}).get("meta", {}).get("filter_mode")
-        if model_ids and filter_mode == "exclude":
-            model_ids = [
-                m["id"]
-                for m in list(models.values())
-                if m.get("owned_by") != "arena" and m["id"] not in model_ids
-            ]
-        if isinstance(model_ids, list) and model_ids:
-            selected_model_id = random.choice(model_ids)
-        else:
-            model_ids = [
-                m["id"]
-                for m in list(models.values())
-                if m.get("owned_by") != "arena"
-            ]
-            selected_model_id = random.choice(model_ids)
-        inner_form = dict(form_data)
-        inner_form["model"] = selected_model_id
-        response = await generate_embeddings(
-            request, inner_form, user, bypass_filter=True
-        )
-        # Tag which concreted model was chosen
-        if isinstance(response, dict):
-            response = {
-                **response,
-                "selected_model_id": selected_model_id,
-            }
-        return response
-
-    # Pipeline/Function models
-    if model.get("pipe"):
-        # The pipeline handler should provide OpenAI-compatible schema
-        return await process_pipeline_inlet_filter(request, form_data, user, models)
-
     # Ollama backend
     # Ollama backend
     if model.get("owned_by") == "ollama":
     if model.get("owned_by") == "ollama":
         ollama_payload = convert_embedding_payload_openai_to_ollama(form_data)
         ollama_payload = convert_embedding_payload_openai_to_ollama(form_data)
-        form_obj = GenerateEmbeddingsForm(**ollama_payload)
         response = await ollama_embeddings(
         response = await ollama_embeddings(
             request=request,
             request=request,
-            form_data=form_obj,
+            form_data=GenerateEmbeddingsForm(**ollama_payload),
             user=user,
             user=user,
         )
         )
         return convert_embedding_response_ollama_to_openai(response)
         return convert_embedding_response_ollama_to_openai(response)
@@ -124,4 +87,4 @@ async def generate_embeddings(
         request=request,
         request=request,
         form_data=form_data,
         form_data=form_data,
         user=user,
         user=user,
-    )
+    )