Переглянути джерело

feat: azure openai support

Timothy Jaeryang Baek 4 місяців тому
батько
коміт
caeb822cdc

+ 130 - 75
backend/open_webui/routers/openai.py

@@ -463,60 +463,88 @@ async def get_models(
         url = request.app.state.config.OPENAI_API_BASE_URLS[url_idx]
         key = request.app.state.config.OPENAI_API_KEYS[url_idx]
 
+        api_config = request.app.state.config.OPENAI_API_CONFIGS.get(
+            str(url_idx),
+            request.app.state.config.OPENAI_API_CONFIGS.get(url, {}),  # Legacy support
+        )
+
         r = None
         async with aiohttp.ClientSession(
             trust_env=True,
             timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST),
         ) as session:
             try:
-                async with session.get(
-                    f"{url}/models",
-                    headers={
-                        "Authorization": f"Bearer {key}",
-                        "Content-Type": "application/json",
-                        **(
-                            {
-                                "X-OpenWebUI-User-Name": user.name,
-                                "X-OpenWebUI-User-Id": user.id,
-                                "X-OpenWebUI-User-Email": user.email,
-                                "X-OpenWebUI-User-Role": user.role,
-                            }
-                            if ENABLE_FORWARD_USER_INFO_HEADERS
-                            else {}
-                        ),
-                    },
-                    ssl=AIOHTTP_CLIENT_SESSION_SSL,
-                ) as r:
-                    if r.status != 200:
-                        # Extract response error details if available
-                        error_detail = f"HTTP Error: {r.status}"
-                        res = await r.json()
-                        if "error" in res:
-                            error_detail = f"External Error: {res['error']}"
-                        raise Exception(error_detail)
-
-                    response_data = await r.json()
-
-                    # Check if we're calling OpenAI API based on the URL
-                    if "api.openai.com" in url:
-                        # Filter models according to the specified conditions
-                        response_data["data"] = [
-                            model
-                            for model in response_data.get("data", [])
-                            if not any(
-                                name in model["id"]
-                                for name in [
-                                    "babbage",
-                                    "dall-e",
-                                    "davinci",
-                                    "embedding",
-                                    "tts",
-                                    "whisper",
-                                ]
-                            )
-                        ]
-
-                    models = response_data
+                headers = {
+                    "Content-Type": "application/json",
+                    **(
+                        {
+                            "X-OpenWebUI-User-Name": user.name,
+                            "X-OpenWebUI-User-Id": user.id,
+                            "X-OpenWebUI-User-Email": user.email,
+                            "X-OpenWebUI-User-Role": user.role,
+                        }
+                        if ENABLE_FORWARD_USER_INFO_HEADERS
+                        else {}
+                    ),
+                }
+
+                if api_config.get("azure", False):
+                    headers["api-key"] = key
+
+                    api_version = api_config.get("api_version", "2023-03-15-preview")
+                    async with session.get(
+                        f"{url}/openai/deployments?api-version={api_version}",
+                        headers=headers,
+                        ssl=AIOHTTP_CLIENT_SESSION_SSL,
+                    ) as r:
+                        if r.status != 200:
+                            # Extract response error details if available
+                            error_detail = f"HTTP Error: {r.status}"
+                            res = await r.json()
+                            if "error" in res:
+                                error_detail = f"External Error: {res['error']}"
+                            raise Exception(error_detail)
+
+                        response_data = await r.json()
+                        models = response_data
+                else:
+                    headers["Authorization"] = f"Bearer {key}"
+
+                    async with session.get(
+                        f"{url}/models",
+                        headers=headers,
+                        ssl=AIOHTTP_CLIENT_SESSION_SSL,
+                    ) as r:
+                        if r.status != 200:
+                            # Extract response error details if available
+                            error_detail = f"HTTP Error: {r.status}"
+                            res = await r.json()
+                            if "error" in res:
+                                error_detail = f"External Error: {res['error']}"
+                            raise Exception(error_detail)
+
+                        response_data = await r.json()
+
+                        # Check if we're calling OpenAI API based on the URL
+                        if "api.openai.com" in url:
+                            # Filter models according to the specified conditions
+                            response_data["data"] = [
+                                model
+                                for model in response_data.get("data", [])
+                                if not any(
+                                    name in model["id"]
+                                    for name in [
+                                        "babbage",
+                                        "dall-e",
+                                        "davinci",
+                                        "embedding",
+                                        "tts",
+                                        "whisper",
+                                    ]
+                                )
+                            ]
+
+                        models = response_data
             except aiohttp.ClientError as e:
                 # ClientError covers all aiohttp requests issues
                 log.exception(f"Client error: {str(e)}")
@@ -538,6 +566,8 @@ class ConnectionVerificationForm(BaseModel):
     url: str
     key: str
 
+    config: Optional[dict] = None
+
 
 @router.post("/verify")
 async def verify_connection(
@@ -546,39 +576,64 @@ async def verify_connection(
     url = form_data.url
     key = form_data.key
 
+    api_config = form_data.config or {}
+
     async with aiohttp.ClientSession(
         trust_env=True,
         timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST),
     ) as session:
         try:
-            async with session.get(
-                f"{url}/models",
-                headers={
-                    "Authorization": f"Bearer {key}",
-                    "Content-Type": "application/json",
-                    **(
-                        {
-                            "X-OpenWebUI-User-Name": user.name,
-                            "X-OpenWebUI-User-Id": user.id,
-                            "X-OpenWebUI-User-Email": user.email,
-                            "X-OpenWebUI-User-Role": user.role,
-                        }
-                        if ENABLE_FORWARD_USER_INFO_HEADERS
-                        else {}
-                    ),
-                },
-                ssl=AIOHTTP_CLIENT_SESSION_SSL,
-            ) as r:
-                if r.status != 200:
-                    # Extract response error details if available
-                    error_detail = f"HTTP Error: {r.status}"
-                    res = await r.json()
-                    if "error" in res:
-                        error_detail = f"External Error: {res['error']}"
-                    raise Exception(error_detail)
+            headers = {
+                "Content-Type": "application/json",
+                **(
+                    {
+                        "X-OpenWebUI-User-Name": user.name,
+                        "X-OpenWebUI-User-Id": user.id,
+                        "X-OpenWebUI-User-Email": user.email,
+                        "X-OpenWebUI-User-Role": user.role,
+                    }
+                    if ENABLE_FORWARD_USER_INFO_HEADERS
+                    else {}
+                ),
+            }
+
+            if api_config.get("azure", False):
+                headers["api-key"] = key
 
-                response_data = await r.json()
-                return response_data
+                api_version = api_config.get("api_version", "2023-03-15-preview")
+                async with session.get(
+                    f"{url}/openai/deployments?api-version={api_version}",
+                    headers=headers,
+                    ssl=AIOHTTP_CLIENT_SESSION_SSL,
+                ) as r:
+                    if r.status != 200:
+                        # Extract response error details if available
+                        error_detail = f"HTTP Error: {r.status}"
+                        res = await r.json()
+                        if "error" in res:
+                            error_detail = f"External Error: {res['error']}"
+                        raise Exception(error_detail)
+
+                    response_data = await r.json()
+                    return response_data
+            else:
+                headers["Authorization"] = f"Bearer {key}"
+
+                async with session.get(
+                    f"{url}/models",
+                    headers=headers,
+                    ssl=AIOHTTP_CLIENT_SESSION_SSL,
+                ) as r:
+                    if r.status != 200:
+                        # Extract response error details if available
+                        error_detail = f"HTTP Error: {r.status}"
+                        res = await r.json()
+                        if "error" in res:
+                            error_detail = f"External Error: {res['error']}"
+                        raise Exception(error_detail)
+
+                    response_data = await r.json()
+                    return response_data
 
         except aiohttp.ClientError as e:
             # ClientError covers all aiohttp requests issues

+ 2 - 7
src/lib/apis/ollama/index.ts

@@ -1,10 +1,6 @@
 import { OLLAMA_API_BASE_URL } from '$lib/constants';
 
-export const verifyOllamaConnection = async (
-	token: string = '',
-	url: string = '',
-	key: string = ''
-) => {
+export const verifyOllamaConnection = async (token: string = '', connection: dict = {}) => {
 	let error = null;
 
 	const res = await fetch(`${OLLAMA_API_BASE_URL}/verify`, {
@@ -15,8 +11,7 @@ export const verifyOllamaConnection = async (
 			'Content-Type': 'application/json'
 		},
 		body: JSON.stringify({
-			url,
-			key
+			...connection
 		})
 	})
 		.then(async (res) => {

+ 4 - 3
src/lib/apis/openai/index.ts

@@ -267,10 +267,10 @@ export const getOpenAIModels = async (token: string, urlIdx?: number) => {
 
 export const verifyOpenAIConnection = async (
 	token: string = '',
-	url: string = 'https://api.openai.com/v1',
-	key: string = '',
+	connection: dict = {},
 	direct: boolean = false
 ) => {
+	const { url, key, config } = connection;
 	if (!url) {
 		throw 'OpenAI: URL is required';
 	}
@@ -309,7 +309,8 @@ export const verifyOpenAIConnection = async (
 			},
 			body: JSON.stringify({
 				url,
-				key
+				key,
+				config
 			})
 		})
 			.then(async (res) => {

+ 42 - 26
src/lib/components/AddConnectionModal.svelte

@@ -33,7 +33,9 @@
 	let connectionType = 'external';
 	let azure = false;
 	$: azure =
-		url.includes('openai.azure.com') || url.includes('cognitive.microsoft.com') ? true : false;
+		(url.includes('openai.azure.com') || url.includes('cognitive.microsoft.com')) && !direct
+			? true
+			: false;
 
 	let prefixId = '';
 	let enable = true;
@@ -47,7 +49,10 @@
 	let loading = false;
 
 	const verifyOllamaHandler = async () => {
-		const res = await verifyOllamaConnection(localStorage.token, url, key).catch((error) => {
+		const res = await verifyOllamaConnection(localStorage.token, {
+			url,
+			key
+		}).catch((error) => {
 			toast.error(`${error}`);
 		});
 
@@ -57,11 +62,20 @@
 	};
 
 	const verifyOpenAIHandler = async () => {
-		const res = await verifyOpenAIConnection(localStorage.token, url, key, direct).catch(
-			(error) => {
-				toast.error(`${error}`);
-			}
-		);
+		const res = await verifyOpenAIConnection(
+			localStorage.token,
+			{
+				url,
+				key,
+				config: {
+					azure: azure,
+					api_version: apiVersion
+				}
+			},
+			direct
+		).catch((error) => {
+			toast.error(`${error}`);
+		});
 
 		if (res) {
 			toast.success($i18n.t('Server connection verified'));
@@ -187,27 +201,29 @@
 					}}
 				>
 					<div class="px-1">
-						<div class="flex gap-2">
-							<div class="flex w-full justify-between items-center">
-								<div class=" text-xs text-gray-500">{$i18n.t('Connection Type')}</div>
-
-								<div class="">
-									<button
-										on:click={() => {
-											connectionType = connectionType === 'local' ? 'external' : 'local';
-										}}
-										type="button"
-										class=" text-xs text-gray-700 dark:text-gray-300"
-									>
-										{#if connectionType === 'local'}
-											{$i18n.t('Local')}
-										{:else}
-											{$i18n.t('External')}
-										{/if}
-									</button>
+						{#if !direct}
+							<div class="flex gap-2">
+								<div class="flex w-full justify-between items-center">
+									<div class=" text-xs text-gray-500">{$i18n.t('Connection Type')}</div>
+
+									<div class="">
+										<button
+											on:click={() => {
+												connectionType = connectionType === 'local' ? 'external' : 'local';
+											}}
+											type="button"
+											class=" text-xs text-gray-700 dark:text-gray-300"
+										>
+											{#if connectionType === 'local'}
+												{$i18n.t('Local')}
+											{:else}
+												{$i18n.t('External')}
+											{/if}
+										</button>
+									</div>
 								</div>
 							</div>
-						</div>
+						{/if}
 
 						<div class="flex gap-2 mt-1.5">
 							<div class="flex flex-col w-full">

+ 0 - 4
src/lib/components/AddServerModal.svelte

@@ -3,10 +3,6 @@
 	import { getContext, onMount } from 'svelte';
 	const i18n = getContext('i18n');
 
-	import { models } from '$lib/stores';
-	import { verifyOpenAIConnection } from '$lib/apis/openai';
-	import { verifyOllamaConnection } from '$lib/apis/ollama';
-
 	import Modal from '$lib/components/common/Modal.svelte';
 	import Plus from '$lib/components/icons/Plus.svelte';
 	import Minus from '$lib/components/icons/Minus.svelte';