Explorar o código

refac: oauth auth type in openai connection

Timothy Jaeryang Baek hai 4 semanas
pai
achega
2b2d123531

+ 113 - 109
backend/open_webui/routers/openai.py

@@ -119,6 +119,74 @@ def openai_reasoning_model_handler(payload):
     return payload
 
 
+def get_headers_and_cookies(
+    request: Request,
+    url,
+    key=None,
+    config=None,
+    metadata: Optional[dict] = None,
+    user: UserModel = None,
+):
+    cookies = {}
+    headers = {
+        "Content-Type": "application/json",
+        **(
+            {
+                "HTTP-Referer": "https://openwebui.com/",
+                "X-Title": "Open WebUI",
+            }
+            if "openrouter.ai" in url
+            else {}
+        ),
+        **(
+            {
+                "X-OpenWebUI-User-Name": quote(user.name, safe=" "),
+                "X-OpenWebUI-User-Id": user.id,
+                "X-OpenWebUI-User-Email": user.email,
+                "X-OpenWebUI-User-Role": user.role,
+                **(
+                    {"X-OpenWebUI-Chat-Id": metadata.get("chat_id")}
+                    if metadata and metadata.get("chat_id")
+                    else {}
+                ),
+            }
+            if ENABLE_FORWARD_USER_INFO_HEADERS
+            else {}
+        ),
+    }
+
+    token = None
+    auth_type = config.get("auth_type")
+
+    if auth_type == "bearer" or auth_type is None:
+        # Default to bearer if not specified
+        token = f"{key}"
+    elif auth_type == "none":
+        token = None
+    elif auth_type == "session":
+        cookies = request.cookies
+        token = request.state.token.credentials
+    elif auth_type == "oauth":
+        cookies = request.cookies
+
+        oauth_token = None
+        try:
+            oauth_token = request.app.state.oauth_manager.get_oauth_token(
+                user.id,
+                request.cookies.get("oauth_session_id", None),
+            )
+        except Exception as e:
+            log.error(f"Error getting OAuth token: {e}")
+
+        if oauth_token:
+            token = f"{oauth_token.get('access_token', '')}"
+
+    if token:
+        headers["Authorization"] = f"Bearer {token}"
+
+    return headers, cookies
+
+
 ##########################################
 #
 # API routes
@@ -210,34 +278,23 @@ async def speech(request: Request, user=Depends(get_verified_user)):
             return FileResponse(file_path)
 
         url = request.app.state.config.OPENAI_API_BASE_URLS[idx]
+        key = request.app.state.config.OPENAI_API_KEYS[idx]
+        api_config = request.app.state.config.OPENAI_API_CONFIGS.get(
+            str(idx),
+            request.app.state.config.OPENAI_API_CONFIGS.get(url, {}),  # Legacy support
+        )
+
+        headers, cookies = get_headers_and_cookies(
+            request, url, key, api_config, user=user
+        )
 
         r = None
         try:
             r = requests.post(
                 url=f"{url}/audio/speech",
                 data=body,
-                headers={
-                    "Content-Type": "application/json",
-                    "Authorization": f"Bearer {request.app.state.config.OPENAI_API_KEYS[idx]}",
-                    **(
-                        {
-                            "HTTP-Referer": "https://openwebui.com/",
-                            "X-Title": "Open WebUI",
-                        }
-                        if "openrouter.ai" in url
-                        else {}
-                    ),
-                    **(
-                        {
-                            "X-OpenWebUI-User-Name": quote(user.name, safe=" "),
-                            "X-OpenWebUI-User-Id": user.id,
-                            "X-OpenWebUI-User-Email": user.email,
-                            "X-OpenWebUI-User-Role": user.role,
-                        }
-                        if ENABLE_FORWARD_USER_INFO_HEADERS
-                        else {}
-                    ),
-                },
+                headers=headers,
+                cookies=cookies,
                 stream=True,
             )
 
@@ -401,7 +458,10 @@ async def get_filtered_models(models, user):
     return filtered_models
 
 
-@cached(ttl=MODELS_CACHE_TTL, key=lambda _, user: f"openai_all_models_{user.id}" if user else "openai_all_models")
+@cached(
+    ttl=MODELS_CACHE_TTL,
+    key=lambda _, user: f"openai_all_models_{user.id}" if user else "openai_all_models",
+)
 async def get_all_models(request: Request, user: UserModel) -> dict[str, list]:
     log.info("get_all_models()")
 
@@ -489,19 +549,9 @@ async def get_models(
             timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST),
         ) as session:
             try:
-                headers = {
-                    "Content-Type": "application/json",
-                    **(
-                        {
-                            "X-OpenWebUI-User-Name": quote(user.name, safe=" "),
-                            "X-OpenWebUI-User-Id": user.id,
-                            "X-OpenWebUI-User-Email": user.email,
-                            "X-OpenWebUI-User-Role": user.role,
-                        }
-                        if ENABLE_FORWARD_USER_INFO_HEADERS
-                        else {}
-                    ),
-                }
+                headers, cookies = get_headers_and_cookies(
+                    request, url, key, api_config, user=user
+                )
 
                 if api_config.get("azure", False):
                     models = {
@@ -509,11 +559,10 @@ async def get_models(
                         "object": "list",
                     }
                 else:
-                    headers["Authorization"] = f"Bearer {key}"
-
                     async with session.get(
                         f"{url}/models",
                         headers=headers,
+                        cookies=cookies,
                         ssl=AIOHTTP_CLIENT_SESSION_SSL,
                     ) as r:
                         if r.status != 200:
@@ -572,7 +621,9 @@ class ConnectionVerificationForm(BaseModel):
 
 @router.post("/verify")
 async def verify_connection(
-    form_data: ConnectionVerificationForm, user=Depends(get_admin_user)
+    request: Request,
+    form_data: ConnectionVerificationForm,
+    user=Depends(get_admin_user),
 ):
     url = form_data.url
     key = form_data.key
@@ -584,19 +635,9 @@ async def verify_connection(
         timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST),
     ) as session:
         try:
-            headers = {
-                "Content-Type": "application/json",
-                **(
-                    {
-                        "X-OpenWebUI-User-Name": quote(user.name, safe=" "),
-                        "X-OpenWebUI-User-Id": user.id,
-                        "X-OpenWebUI-User-Email": user.email,
-                        "X-OpenWebUI-User-Role": user.role,
-                    }
-                    if ENABLE_FORWARD_USER_INFO_HEADERS
-                    else {}
-                ),
-            }
+            headers, cookies = get_headers_and_cookies(
+                request, url, key, api_config, user=user
+            )
 
             if api_config.get("azure", False):
                 headers["api-key"] = key
@@ -605,6 +646,7 @@ async def verify_connection(
                 async with session.get(
                     url=f"{url}/openai/models?api-version={api_version}",
                     headers=headers,
+                    cookies=cookies,
                     ssl=AIOHTTP_CLIENT_SESSION_SSL,
                 ) as r:
                     try:
@@ -624,11 +666,10 @@ async def verify_connection(
 
                     return response_data
             else:
-                headers["Authorization"] = f"Bearer {key}"
-
                 async with session.get(
                     f"{url}/models",
                     headers=headers,
+                    cookies=cookies,
                     ssl=AIOHTTP_CLIENT_SESSION_SSL,
                 ) as r:
                     try:
@@ -836,32 +877,9 @@ async def generate_chat_completion(
             convert_logit_bias_input_to_json(payload["logit_bias"])
         )
 
-    headers = {
-        "Content-Type": "application/json",
-        **(
-            {
-                "HTTP-Referer": "https://openwebui.com/",
-                "X-Title": "Open WebUI",
-            }
-            if "openrouter.ai" in url
-            else {}
-        ),
-        **(
-            {
-                "X-OpenWebUI-User-Name": quote(user.name, safe=" "),
-                "X-OpenWebUI-User-Id": user.id,
-                "X-OpenWebUI-User-Email": user.email,
-                "X-OpenWebUI-User-Role": user.role,
-                **(
-                    {"X-OpenWebUI-Chat-Id": metadata.get("chat_id")}
-                    if metadata and metadata.get("chat_id")
-                    else {}
-                ),
-            }
-            if ENABLE_FORWARD_USER_INFO_HEADERS
-            else {}
-        ),
-    }
+    headers, cookies = get_headers_and_cookies(
+        request, url, key, api_config, metadata, user=user
+    )
 
     if api_config.get("azure", False):
         api_version = api_config.get("api_version", "2023-03-15-preview")
@@ -871,7 +889,6 @@ async def generate_chat_completion(
         request_url = f"{request_url}/chat/completions?api-version={api_version}"
     else:
         request_url = f"{url}/chat/completions"
-        headers["Authorization"] = f"Bearer {key}"
 
     payload = json.dumps(payload)
 
@@ -890,6 +907,7 @@ async def generate_chat_completion(
             url=request_url,
             data=payload,
             headers=headers,
+            cookies=cookies,
             ssl=AIOHTTP_CLIENT_SESSION_SSL,
         )
 
@@ -951,31 +969,27 @@ async def embeddings(request: Request, form_data: dict, user):
     models = request.app.state.OPENAI_MODELS
     if model_id in models:
         idx = models[model_id]["urlIdx"]
+
     url = request.app.state.config.OPENAI_API_BASE_URLS[idx]
     key = request.app.state.config.OPENAI_API_KEYS[idx]
+    api_config = request.app.state.config.OPENAI_API_CONFIGS.get(
+        str(idx),
+        request.app.state.config.OPENAI_API_CONFIGS.get(url, {}),  # Legacy support
+    )
+
     r = None
     session = None
     streaming = False
+
+    headers, cookies = get_headers_and_cookies(request, url, key, api_config, user=user)
     try:
         session = aiohttp.ClientSession(trust_env=True)
         r = await session.request(
             method="POST",
             url=f"{url}/embeddings",
             data=body,
-            headers={
-                "Authorization": f"Bearer {key}",
-                "Content-Type": "application/json",
-                **(
-                    {
-                        "X-OpenWebUI-User-Name": quote(user.name, safe=" "),
-                        "X-OpenWebUI-User-Id": user.id,
-                        "X-OpenWebUI-User-Email": user.email,
-                        "X-OpenWebUI-User-Role": user.role,
-                    }
-                    if ENABLE_FORWARD_USER_INFO_HEADERS and user
-                    else {}
-                ),
-            },
+            headers=headers,
+            cookies=cookies,
         )
 
         if "text/event-stream" in r.headers.get("Content-Type", ""):
@@ -1037,19 +1051,9 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
     streaming = False
 
     try:
-        headers = {
-            "Content-Type": "application/json",
-            **(
-                {
-                    "X-OpenWebUI-User-Name": quote(user.name, safe=" "),
-                    "X-OpenWebUI-User-Id": user.id,
-                    "X-OpenWebUI-User-Email": user.email,
-                    "X-OpenWebUI-User-Role": user.role,
-                }
-                if ENABLE_FORWARD_USER_INFO_HEADERS
-                else {}
-            ),
-        }
+        headers, cookies = get_headers_and_cookies(
+            request, url, key, api_config, user=user
+        )
 
         if api_config.get("azure", False):
             api_version = api_config.get("api_version", "2023-03-15-preview")
@@ -1062,7 +1066,6 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
 
             request_url = f"{url}/{path}?api-version={api_version}"
         else:
-            headers["Authorization"] = f"Bearer {key}"
             request_url = f"{url}/{path}"
 
         session = aiohttp.ClientSession(trust_env=True)
@@ -1071,6 +1074,7 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
             url=request_url,
             data=body,
             headers=headers,
+            cookies=cookies,
             ssl=AIOHTTP_CLIENT_SESSION_SSL,
         )
 

+ 8 - 11
backend/open_webui/utils/tools.py

@@ -138,12 +138,10 @@ async def get_tools(
                     elif auth_type == "oauth":
                         cookies = request.cookies
                         oauth_token = extra_params.get("__oauth_token__", None)
-                        headers["Authorization"] = (
-                            f"Bearer {oauth_token.get('access_token', '')}"
-                        )
-                    elif auth_type == "request_headers":
-                        cookies = request.cookies
-                        headers.update(dict(request.headers))
+                        if oauth_token:
+                            headers["Authorization"] = (
+                                f"Bearer {oauth_token.get('access_token', '')}"
+                            )
 
                     headers["Content-Type"] = "application/json"
 
@@ -564,9 +562,7 @@ async def get_tool_server_data(token: str, url: str) -> Dict[str, Any]:
     return data
 
 
-async def get_tool_servers_data(
-    servers: List[Dict[str, Any]], session_token: Optional[str] = None
-) -> List[Dict[str, Any]]:
+async def get_tool_servers_data(servers: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
     # Prepare list of enabled servers along with their original index
     server_entries = []
     for idx, server in enumerate(servers):
@@ -582,8 +578,9 @@ async def get_tool_servers_data(
 
             if auth_type == "bearer":
                 token = server.get("key", "")
-            elif auth_type == "session":
-                token = session_token
+            elif auth_type == "none":
+                # No authentication
+                pass
 
             id = info.get("id")
             if not id:

+ 0 - 2
src/lib/components/AddConnectionModal.svelte

@@ -443,8 +443,6 @@
 							</div>
 						{/if}
 
-						<hr class=" border-gray-50 dark:border-gray-850 my-2.5 w-full" />
-
 						<div class="flex flex-col w-full">
 							<div class="mb-1 flex justify-between">
 								<div