Browse Source

refac: azure openai

Timothy Jaeryang Baek 1 month ago
parent
commit
2ab5aa4d34

+ 10 - 22
backend/open_webui/routers/openai.py

@@ -489,24 +489,10 @@ async def get_models(
                 }
 
                 if api_config.get("azure", False):
-                    headers["api-key"] = key
-
-                    api_version = api_config.get("api_version", "2023-03-15-preview")
-                    async with session.get(
-                        f"{url}/openai/deployments?api-version={api_version}",
-                        headers=headers,
-                        ssl=AIOHTTP_CLIENT_SESSION_SSL,
-                    ) as r:
-                        if r.status != 200:
-                            # Extract response error details if available
-                            error_detail = f"HTTP Error: {r.status}"
-                            res = await r.json()
-                            if "error" in res:
-                                error_detail = f"External Error: {res['error']}"
-                            raise Exception(error_detail)
-
-                        response_data = await r.json()
-                        models = response_data
+                    models = {
+                        "data": api_config.get("model_ids", []) or [],
+                        "object": "list",
+                    }
                 else:
                     headers["Authorization"] = f"Bearer {key}"
 
@@ -599,10 +585,10 @@ async def verify_connection(
 
             if api_config.get("azure", False):
                 headers["api-key"] = key
+                api_version = api_config.get("api_version", "") or "2023-03-15-preview"
 
-                api_version = api_config.get("api_version", "2023-03-15-preview")
                 async with session.get(
-                    f"{url}/openai/deployments?api-version={api_version}",
+                    url=f"{url}/openai/models?api-version={api_version}",
                     headers=headers,
                     ssl=AIOHTTP_CLIENT_SESSION_SSL,
                 ) as r:
@@ -828,7 +814,7 @@ async def generate_chat_completion(
 
     if api_config.get("azure", False):
         request_url, payload = convert_to_azure_payload(url, payload)
-        api_version = api_config.get("api_version", "2023-03-15-preview")
+        api_version = api_config.get("api_version", "") or "2023-03-15-preview"
         headers["api-key"] = key
         headers["api-version"] = api_version
         request_url = f"{request_url}/chat/completions?api-version={api_version}"
@@ -936,7 +922,9 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
 
         if api_config.get("azure", False):
             headers["api-key"] = key
-            headers["api-version"] = api_config.get("api_version", "2023-03-15-preview")
+            headers["api-version"] = (
+                api_config.get("api_version", "") or "2023-03-15-preview"
+            )
 
             payload = json.loads(body)
             url, payload = convert_to_azure_payload(url, payload)

+ 28 - 4
src/lib/components/AddConnectionModal.svelte

@@ -33,7 +33,7 @@
 	let connectionType = 'external';
 	let azure = false;
 	$: azure =
-		(url.includes('openai.azure.com') || url.includes('cognitive.microsoft.com')) && !direct
+		(url.includes('azure.com') || url.includes('cognitive.microsoft.com')) && !direct
 			? true
 			: false;
 
@@ -106,6 +106,28 @@
 			return;
 		}
 
+		if (azure) {
+			if (!apiVersion) {
+				loading = false;
+
+				toast.error('API Version is required');
+				return;
+			}
+
+			if (!key) {
+				loading = false;
+
+				toast.error('Key is required');
+				return;
+			}
+
+			if (modelIds.length === 0) {
+				loading = false;
+				toast.error('Deployment names are required');
+				return;
+			}
+		}
+
 		// remove trailing slash from url
 		url = url.replace(/\/$/, '');
 
@@ -149,6 +171,7 @@
 			} else {
 				connectionType = connection.config?.connection_type ?? 'external';
 				azure = connection.config?.azure ?? false;
+				apiVersion = connection.config?.api_version ?? '';
 			}
 		}
 	};
@@ -382,9 +405,10 @@
 											url: url
 										})}
 									{:else if azure}
-										{$i18n.t('Leave empty to include all models from "{{url}}" endpoint', {
+										{$i18n.t('Deployment names are required for Azure OpenAI.')}
+										<!-- {$i18n.t('Leave empty to include all models from "{{url}}" endpoint', {
 											url: `${url}/openai/deployments`
-										})}
+										})} -->
 									{:else}
 										{$i18n.t('Leave empty to include all models from "{{url}}/models" endpoint', {
 											url: url
@@ -394,7 +418,7 @@
 							{/if}
 						</div>
 
-						<hr class=" border-gray-100 dark:border-gray-700/10 my-2.5 w-full" />
+						<hr class=" border-gray-100 dark:border-gray-700/10 my-1.5 w-full" />
 
 						<div class="flex items-center">
 							<input

+ 1 - 1
src/lib/components/admin/Settings/Connections/OpenAIConnection.svelte

@@ -62,7 +62,7 @@
 				class="absolute top-0 bottom-0 left-0 right-0 opacity-60 bg-white dark:bg-gray-900 z-10"
 			></div>
 		{/if}
-		<div class="flex w-full">
+		<div class="flex w-full gap-2">
 			<div class="flex-1 relative">
 				<input
 					class=" outline-hidden w-full bg-transparent {pipeline ? 'pr-8' : ''}"

+ 1 - 1
src/lib/components/chat/Settings/Connections/Connection.svelte

@@ -62,7 +62,7 @@
 				class="absolute top-0 bottom-0 left-0 right-0 opacity-60 bg-white dark:bg-gray-900 z-10"
 			></div>
 		{/if}
-		<div class="flex w-full">
+		<div class="flex w-full gap-2">
 			<div class="flex-1 relative">
 				<input
 					class=" outline-hidden w-full bg-transparent {pipeline ? 'pr-8' : ''}"