|
@@ -9,6 +9,8 @@ from aiocache import cached
|
|
|
import requests
|
|
|
from urllib.parse import quote
|
|
|
|
|
|
+from azure.identity import DefaultAzureCredential, get_bearer_token_provider
|
|
|
+
|
|
|
from fastapi import Depends, HTTPException, Request, APIRouter
|
|
|
from fastapi.responses import (
|
|
|
FileResponse,
|
|
@@ -182,12 +184,30 @@ def get_headers_and_cookies(
|
|
|
if oauth_token:
|
|
|
token = f"{oauth_token.get('access_token', '')}"
|
|
|
|
|
|
+ elif auth_type in ("azure_ad", "azure_entra_id"):
|
|
|
+ token = get_azure_entra_id_access_token()
|
|
|
+
|
|
|
if token:
|
|
|
headers["Authorization"] = f"Bearer {token}"
|
|
|
|
|
|
return headers, cookies
|
|
|
|
|
|
|
|
|
+def get_azure_entra_id_access_token():
|
|
|
+ """
|
|
|
+ Get Azure access token using DefaultAzureCredential for Azure OpenAI.
|
|
|
+ Returns the token string or None if authentication fails.
|
|
|
+ """
|
|
|
+ try:
|
|
|
+ token_provider = get_bearer_token_provider(
|
|
|
+ DefaultAzureCredential(), "https://cognitiveservices.azure.com/.default"
|
|
|
+ )
|
|
|
+ return token_provider()
|
|
|
+ except Exception as e:
|
|
|
+ log.error(f"Error getting Azure access token: {e}")
|
|
|
+ return None
|
|
|
+
|
|
|
+
|
|
|
##########################################
|
|
|
#
|
|
|
# API routes
|
|
@@ -641,9 +661,12 @@ async def verify_connection(
|
|
|
)
|
|
|
|
|
|
if api_config.get("azure", False):
|
|
|
- headers["api-key"] = key
|
|
|
- api_version = api_config.get("api_version", "") or "2023-03-15-preview"
|
|
|
+ # Only set api-key header if not using Azure Entra ID authentication
|
|
|
+ auth_type = api_config.get("auth_type", "bearer")
|
|
|
+ if auth_type not in ("azure_ad", "azure_entra_id"):
|
|
|
+ headers["api-key"] = key
|
|
|
|
|
|
+ api_version = api_config.get("api_version", "") or "2023-03-15-preview"
|
|
|
async with session.get(
|
|
|
url=f"{url}/openai/models?api-version={api_version}",
|
|
|
headers=headers,
|
|
@@ -885,7 +908,12 @@ async def generate_chat_completion(
|
|
|
if api_config.get("azure", False):
|
|
|
api_version = api_config.get("api_version", "2023-03-15-preview")
|
|
|
request_url, payload = convert_to_azure_payload(url, payload, api_version)
|
|
|
- headers["api-key"] = key
|
|
|
+
|
|
|
+ # Only set api-key header if not using Azure Entra ID authentication
|
|
|
+ auth_type = api_config.get("auth_type", "bearer")
|
|
|
+ if auth_type not in ("azure_ad", "azure_entra_id"):
|
|
|
+ headers["api-key"] = key
|
|
|
+
|
|
|
headers["api-version"] = api_version
|
|
|
request_url = f"{request_url}/chat/completions?api-version={api_version}"
|
|
|
else:
|
|
@@ -1058,7 +1086,12 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
|
|
|
|
|
|
if api_config.get("azure", False):
|
|
|
api_version = api_config.get("api_version", "2023-03-15-preview")
|
|
|
- headers["api-key"] = key
|
|
|
+
|
|
|
+ # Only set api-key header if not using Azure Entra ID authentication
|
|
|
+ auth_type = api_config.get("auth_type", "bearer")
|
|
|
+ if auth_type not in ("azure_ad", "azure_entra_id"):
|
|
|
+ headers["api-key"] = key
|
|
|
+
|
|
|
headers["api-version"] = api_version
|
|
|
|
|
|
payload = json.loads(body)
|