Browse Source

OpenAI embeddings function modified

henry 3 weeks ago
parent
commit
3ddebefca2
2 changed files with 52 additions and 20 deletions
  1. 32 5
      backend/open_webui/main.py
  2. 20 15
      backend/open_webui/routers/openai.py

+ 32 - 5
backend/open_webui/main.py

@@ -411,6 +411,7 @@ from open_webui.utils.chat import (
     chat_completed as chat_completed_handler,
     chat_action as chat_action_handler,
 )
+from open_webui.utils.embeddings import generate_embeddings
 from open_webui.utils.middleware import process_chat_payload, process_chat_response
 from open_webui.utils.access_control import has_access
 
@@ -1363,11 +1364,6 @@ async def list_tasks_by_chat_id_endpoint(chat_id: str, user=Depends(get_verified
     return {"task_ids": task_ids}
 
 
-@app.post("/api/embeddings")
-async def api_embeddings(request: Request, user=Depends(get_verified_user)):
-    return await openai.generate_embeddings(request=request, user=user)
-
-
 ##################################
 #
 # Config Endpoints
@@ -1544,6 +1540,37 @@ async def get_app_latest_release_version(user=Depends(get_verified_user)):
 async def get_app_changelog():
     return {key: CHANGELOG[key] for idx, key in enumerate(CHANGELOG) if idx < 5}
 
+##################################
+# Embeddings
+##################################
+
+@app.post("/api/embeddings")
+async def embeddings_endpoint(
+    request: Request,
+    form_data: dict,
+    user=Depends(get_verified_user)
+):
+    """
+    OpenAI-compatible embeddings endpoint.
+
+    This handler:
+      - Lazily populates ``request.app.state.MODELS`` on first use.
+      - Delegates the actual provider dispatch to
+        ``open_webui.utils.embeddings.generate_embeddings``
+        (presumably routing to OpenAI, Ollama, arena models, pipelines,
+        or another compatible provider — confirm in utils.embeddings).
+
+    Args:
+        request (Request): Request context.
+        form_data (dict): OpenAI-like payload (e.g., {"model": "...", "input": [...]})
+        user (UserModel): Authenticated user (resolved via get_verified_user).
+
+    Returns:
+        dict: OpenAI-compatible embeddings response.
+    """
+    # Make sure models are loaded in app state; get_all_models populates
+    # request.app.state.MODELS as a side effect, so the return value is unused.
+    if not request.app.state.MODELS:
+        await get_all_models(request, user=user)
+    # Use generic dispatcher in utils.embeddings
+    return await generate_embeddings(request, form_data, user)
+
 
 ############################
 # OAuth Login & Callback

+ 20 - 15
backend/open_webui/routers/openai.py

@@ -886,26 +886,36 @@ async def generate_chat_completion(
                 r.close()
             await session.close()
 
-@router.post("/embeddings")
-async def generate_embeddings(request: Request, user=Depends(get_verified_user)):
+async def embeddings(request: Request, form_data: dict, user):
     """
-    Call embeddings endpoint
+    Calls the embeddings endpoint for OpenAI-compatible providers.
+    
+    Args:
+        request (Request): The FastAPI request context.
+        form_data (dict): OpenAI-compatible embeddings payload.
+        user (UserModel): The authenticated user.
+    
+    Returns:
+        dict: OpenAI-compatible embeddings response.
     """
-
-    body = await request.body()
-
     idx = 0
+    # Prepare payload/body
+    body = json.dumps(form_data)
+    # Find correct backend url/key based on model
+    await get_all_models(request, user=user)
+    model_id = form_data.get("model")
+    models = request.app.state.OPENAI_MODELS
+    if model_id in models:
+        idx = models[model_id]["urlIdx"]
     url = request.app.state.config.OPENAI_API_BASE_URLS[idx]
     key = request.app.state.config.OPENAI_API_KEYS[idx]
-
     r = None
     session = None
     streaming = False
-
     try:
         session = aiohttp.ClientSession(trust_env=True)
         r = await session.request(
-            method=request.method,
+            method="POST",
             url=f"{url}/embeddings",
             data=body,
             headers={
@@ -918,14 +928,11 @@ async def generate_embeddings(request: Request, user=Depends(get_verified_user))
                         "X-OpenWebUI-User-Email": user.email,
                         "X-OpenWebUI-User-Role": user.role,
                     }
-                    if ENABLE_FORWARD_USER_INFO_HEADERS
-                    else {}
+                    if ENABLE_FORWARD_USER_INFO_HEADERS and user else {}
                 ),
             },
         )
         r.raise_for_status()
-
-        # Check if response is SSE
         if "text/event-stream" in r.headers.get("Content-Type", ""):
             streaming = True
             return StreamingResponse(
@@ -939,10 +946,8 @@ async def generate_embeddings(request: Request, user=Depends(get_verified_user))
         else:
             response_data = await r.json()
             return response_data
-
     except Exception as e:
         log.exception(e)
-
         detail = None
         if r is not None:
             try: