Browse Source

embeddings function added 100% OpenAI compatible

hdnh2006 7 months ago
parent
commit
0afe972bc6
2 changed files with 79 additions and 0 deletions
  1. 5 0
      backend/open_webui/main.py
  2. 74 0
      backend/open_webui/routers/openai.py

+ 5 - 0
backend/open_webui/main.py

@@ -1038,6 +1038,11 @@ async def list_tasks_endpoint(user=Depends(get_verified_user)):
     return {"tasks": list_tasks()}  # Use the function from tasks.py
 
 
+@app.post("/api/embeddings")
+async def api_embeddings(request: Request, user=Depends(get_verified_user)):
+    return await openai.generate_embeddings(request=request, user=user)
+
+
 ##################################
 #
 # Config Endpoints

+ 74 - 0
backend/open_webui/routers/openai.py

@@ -715,6 +715,80 @@ async def generate_chat_completion(
                 r.close()
             await session.close()
 
+@router.post("/embeddings")
+async def generate_embeddings(request: Request, user=Depends(get_verified_user)):
+    """
+    Call embeddings endpoint
+    """
+
+    body = await request.body()
+
+    idx = 0
+    url = request.app.state.config.OPENAI_API_BASE_URLS[idx]
+    key = request.app.state.config.OPENAI_API_KEYS[idx]
+
+    r = None
+    session = None
+    streaming = False
+
+    try:
+        session = aiohttp.ClientSession(trust_env=True)
+        r = await session.request(
+            method=request.method,
+            url=f"{url}/embeddings",
+            data=body,
+            headers={
+                "Authorization": f"Bearer {key}",
+                "Content-Type": "application/json",
+                **(
+                    {
+                        "X-OpenWebUI-User-Name": user.name,
+                        "X-OpenWebUI-User-Id": user.id,
+                        "X-OpenWebUI-User-Email": user.email,
+                        "X-OpenWebUI-User-Role": user.role,
+                    }
+                    if ENABLE_FORWARD_USER_INFO_HEADERS
+                    else {}
+                ),
+            },
+        )
+        r.raise_for_status()
+
+        # Check if response is SSE
+        if "text/event-stream" in r.headers.get("Content-Type", ""):
+            streaming = True
+            return StreamingResponse(
+                r.content,
+                status_code=r.status,
+                headers=dict(r.headers),
+                background=BackgroundTask(
+                    cleanup_response, response=r, session=session
+                ),
+            )
+        else:
+            response_data = await r.json()
+            return response_data
+
+    except Exception as e:
+        log.exception(e)
+
+        detail = None
+        if r is not None:
+            try:
+                res = await r.json()
+                if "error" in res:
+                    detail = f"External: {res['error']['message'] if 'message' in res['error'] else res['error']}"
+            except Exception:
+                detail = f"External: {e}"
+        raise HTTPException(
+            status_code=r.status if r else 500,
+            detail=detail if detail else "Open WebUI: Server Connection Error",
+        )
+    finally:
+        if not streaming and session:
+            if r:
+                r.close()
+            await session.close()
 
 @router.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])
 async def proxy(path: str, request: Request, user=Depends(get_verified_user)):