Bläddra i källkod

Add Azure OpenAI embedding support

Derek Wischusen 4 månader sedan
förälder
incheckning
42be1f956a

+ 21 - 0
backend/open_webui/config.py

@@ -2124,6 +2124,27 @@ RAG_OPENAI_API_KEY = PersistentConfig(
     os.getenv("RAG_OPENAI_API_KEY", OPENAI_API_KEY),
 )
 
+RAG_AZURE_OPENAI_BASE_URL = PersistentConfig(
+    "RAG_AZURE_OPENAI_BASE_URL",
+    "rag.azure_openai.base_url",
+    os.getenv("RAG_AZURE_OPENAI_BASE_URL", ""),
+)
+RAG_AZURE_OPENAI_API_KEY = PersistentConfig(
+    "RAG_AZURE_OPENAI_API_KEY",
+    "rag.azure_openai.api_key",
+    os.getenv("RAG_AZURE_OPENAI_API_KEY", ""),
+)
+RAG_AZURE_OPENAI_DEPLOYMENT = PersistentConfig(
+    "RAG_AZURE_OPENAI_DEPLOYMENT",
+    "rag.azure_openai.deployment",
+    os.getenv("RAG_AZURE_OPENAI_DEPLOYMENT", ""),
+)
+RAG_AZURE_OPENAI_VERSION = PersistentConfig(
+    "RAG_AZURE_OPENAI_VERSION",
+    "rag.azure_openai.version",
+    os.getenv("RAG_AZURE_OPENAI_VERSION", ""),
+)
+
 RAG_OLLAMA_BASE_URL = PersistentConfig(
     "RAG_OLLAMA_BASE_URL",
     "rag.ollama.url",

+ 29 - 2
backend/open_webui/main.py

@@ -202,6 +202,10 @@ from open_webui.config import (
     RAG_FILE_MAX_SIZE,
     RAG_OPENAI_API_BASE_URL,
     RAG_OPENAI_API_KEY,
+    RAG_AZURE_OPENAI_BASE_URL,
+    RAG_AZURE_OPENAI_API_KEY,
+    RAG_AZURE_OPENAI_DEPLOYMENT,
+    RAG_AZURE_OPENAI_VERSION,
     RAG_OLLAMA_BASE_URL,
     RAG_OLLAMA_API_KEY,
     CHUNK_OVERLAP,
@@ -688,6 +692,11 @@ app.state.config.RAG_TEMPLATE = RAG_TEMPLATE
 app.state.config.RAG_OPENAI_API_BASE_URL = RAG_OPENAI_API_BASE_URL
 app.state.config.RAG_OPENAI_API_KEY = RAG_OPENAI_API_KEY
 
+app.state.config.RAG_AZURE_OPENAI_BASE_URL = RAG_AZURE_OPENAI_BASE_URL
+app.state.config.RAG_AZURE_OPENAI_API_KEY = RAG_AZURE_OPENAI_API_KEY
+app.state.config.RAG_AZURE_OPENAI_DEPLOYMENT = RAG_AZURE_OPENAI_DEPLOYMENT
+app.state.config.RAG_AZURE_OPENAI_VERSION = RAG_AZURE_OPENAI_VERSION
+
 app.state.config.RAG_OLLAMA_BASE_URL = RAG_OLLAMA_BASE_URL
 app.state.config.RAG_OLLAMA_API_KEY = RAG_OLLAMA_API_KEY
 
@@ -781,14 +790,32 @@ app.state.EMBEDDING_FUNCTION = get_embedding_function(
     (
         app.state.config.RAG_OPENAI_API_BASE_URL
         if app.state.config.RAG_EMBEDDING_ENGINE == "openai"
-        else app.state.config.RAG_OLLAMA_BASE_URL
+        else (
+            app.state.config.RAG_OLLAMA_BASE_URL
+            if app.state.config.RAG_EMBEDDING_ENGINE == "ollama"
+            else app.state.config.RAG_AZURE_OPENAI_BASE_URL
+        )
     ),
     (
         app.state.config.RAG_OPENAI_API_KEY
         if app.state.config.RAG_EMBEDDING_ENGINE == "openai"
-        else app.state.config.RAG_OLLAMA_API_KEY
+        else (
+            app.state.config.RAG_OLLAMA_API_KEY
+            if app.state.config.RAG_EMBEDDING_ENGINE == "ollama"
+            else app.state.config.RAG_AZURE_OPENAI_API_KEY
+        )
     ),
     app.state.config.RAG_EMBEDDING_BATCH_SIZE,
+    (
+        app.state.config.RAG_AZURE_OPENAI_DEPLOYMENT
+        if app.state.config.RAG_EMBEDDING_ENGINE == "azure_openai"
+        else None
+    ),
+    (
+        app.state.config.RAG_AZURE_OPENAI_VERSION
+        if app.state.config.RAG_EMBEDDING_ENGINE == "azure_openai"
+        else None
+    ),
 )
 
 ########################################

+ 87 - 1
backend/open_webui/retrieval/utils.py

@@ -5,6 +5,7 @@ from typing import Optional, Union
 import requests
 import hashlib
 from concurrent.futures import ThreadPoolExecutor
+import time
 
 from huggingface_hub import snapshot_download
 from langchain.retrievers import ContextualCompressionRetriever, EnsembleRetriever
@@ -386,12 +387,14 @@ def get_embedding_function(
     url,
     key,
     embedding_batch_size,
+    deployment=None,
+    version=None,
 ):
     if embedding_engine == "":
         return lambda query, prefix=None, user=None: embedding_function.encode(
             query, **({"prompt": prefix} if prefix else {})
         ).tolist()
-    elif embedding_engine in ["ollama", "openai"]:
+    elif embedding_engine in ["ollama", "openai", "azure_openai"]:
         func = lambda query, prefix=None, user=None: generate_embeddings(
             engine=embedding_engine,
             model=embedding_model,
@@ -400,6 +403,8 @@ def get_embedding_function(
             url=url,
             key=key,
             user=user,
+            deployment=deployment,
+            version=version,
         )
 
         def generate_multiple(query, prefix, user, func):
@@ -681,6 +686,61 @@ def generate_openai_batch_embeddings(
         return None
 
 
+def generate_azure_openai_batch_embeddings(
+    deployment: str,
+    texts: list[str],
+    url: str,
+    key: str = "",
+    model: str = "",
+    version: str = "",
+    prefix: str = None,
+    user: UserModel = None,
+) -> Optional[list[list[float]]]:
+    try:
+        log.debug(
+            f"generate_azure_openai_batch_embeddings:deployment {deployment} batch size: {len(texts)}"
+        )
+        json_data = {"input": texts, "model": model}
+        if isinstance(RAG_EMBEDDING_PREFIX_FIELD_NAME, str) and isinstance(prefix, str):
+            json_data[RAG_EMBEDDING_PREFIX_FIELD_NAME] = prefix
+
+        url = f"{url}/openai/deployments/{deployment}/embeddings?api-version={version}"
+
+        for _ in range(5):
+            r = requests.post(
+                url,
+                headers={
+                    "Content-Type": "application/json",
+                    "api-key": key,
+                    **(
+                        {
+                            "X-OpenWebUI-User-Name": user.name,
+                            "X-OpenWebUI-User-Id": user.id,
+                            "X-OpenWebUI-User-Email": user.email,
+                            "X-OpenWebUI-User-Role": user.role,
+                        }
+                        if ENABLE_FORWARD_USER_INFO_HEADERS and user
+                        else {}
+                    ),
+                },
+                json=json_data,
+            )
+            if r.status_code == 429:
+                retry = float(r.headers.get("Retry-After", "1"))
+                time.sleep(retry)
+                continue
+            r.raise_for_status()
+            data = r.json()
+            if "data" in data:
+                return [elem["embedding"] for elem in data["data"]]
+            else:
+                raise Exception("Something went wrong :/")
+        return None
+    except Exception as e:
+        log.exception(f"Error generating azure openai batch embeddings: {e}")
+        return None
+
+
 def generate_ollama_batch_embeddings(
     model: str,
     texts: list[str],
@@ -778,6 +838,32 @@ def generate_embeddings(
                 model, [text], url, key, prefix, user
             )
         return embeddings[0] if isinstance(text, str) else embeddings
+    elif engine == "azure_openai":
+        deployment = kwargs.get("deployment", "")
+        version = kwargs.get("version", "")
+        if isinstance(text, list):
+            embeddings = generate_azure_openai_batch_embeddings(
+                deployment,
+                text,
+                url,
+                key,
+                model,
+                version,
+                prefix,
+                user,
+            )
+        else:
+            embeddings = generate_azure_openai_batch_embeddings(
+                deployment,
+                [text],
+                url,
+                key,
+                model,
+                version,
+                prefix,
+                user,
+            )
+        return embeddings[0] if isinstance(text, str) else embeddings
 
 
 import operator

+ 75 - 5
backend/open_webui/routers/retrieval.py

@@ -239,6 +239,12 @@ async def get_embedding_config(request: Request, user=Depends(get_admin_user)):
             "url": request.app.state.config.RAG_OLLAMA_BASE_URL,
             "key": request.app.state.config.RAG_OLLAMA_API_KEY,
         },
+        "azure_openai_config": {
+            "url": request.app.state.config.RAG_AZURE_OPENAI_BASE_URL,
+            "key": request.app.state.config.RAG_AZURE_OPENAI_API_KEY,
+            "deployment": request.app.state.config.RAG_AZURE_OPENAI_DEPLOYMENT,
+            "version": request.app.state.config.RAG_AZURE_OPENAI_VERSION,
+        },
     }
 
 
@@ -252,9 +258,17 @@ class OllamaConfigForm(BaseModel):
     key: str
 
 
+class AzureOpenAIConfigForm(BaseModel):
+    url: str
+    key: str
+    deployment: str
+    version: str
+
+
 class EmbeddingModelUpdateForm(BaseModel):
     openai_config: Optional[OpenAIConfigForm] = None
     ollama_config: Optional[OllamaConfigForm] = None
+    azure_openai_config: Optional[AzureOpenAIConfigForm] = None
     embedding_engine: str
     embedding_model: str
     embedding_batch_size: Optional[int] = 1
@@ -271,7 +285,7 @@ async def update_embedding_config(
         request.app.state.config.RAG_EMBEDDING_ENGINE = form_data.embedding_engine
         request.app.state.config.RAG_EMBEDDING_MODEL = form_data.embedding_model
 
-        if request.app.state.config.RAG_EMBEDDING_ENGINE in ["ollama", "openai"]:
+        if request.app.state.config.RAG_EMBEDDING_ENGINE in ["ollama", "openai", "azure_openai"]:
             if form_data.openai_config is not None:
                 request.app.state.config.RAG_OPENAI_API_BASE_URL = (
                     form_data.openai_config.url
@@ -288,6 +302,20 @@ async def update_embedding_config(
                     form_data.ollama_config.key
                 )
 
+            if form_data.azure_openai_config is not None:
+                request.app.state.config.RAG_AZURE_OPENAI_BASE_URL = (
+                    form_data.azure_openai_config.url
+                )
+                request.app.state.config.RAG_AZURE_OPENAI_API_KEY = (
+                    form_data.azure_openai_config.key
+                )
+                request.app.state.config.RAG_AZURE_OPENAI_DEPLOYMENT = (
+                    form_data.azure_openai_config.deployment
+                )
+                request.app.state.config.RAG_AZURE_OPENAI_VERSION = (
+                    form_data.azure_openai_config.version
+                )
+
             request.app.state.config.RAG_EMBEDDING_BATCH_SIZE = (
                 form_data.embedding_batch_size
             )
@@ -304,14 +332,32 @@ async def update_embedding_config(
             (
                 request.app.state.config.RAG_OPENAI_API_BASE_URL
                 if request.app.state.config.RAG_EMBEDDING_ENGINE == "openai"
-                else request.app.state.config.RAG_OLLAMA_BASE_URL
+                else (
+                    request.app.state.config.RAG_OLLAMA_BASE_URL
+                    if request.app.state.config.RAG_EMBEDDING_ENGINE == "ollama"
+                    else request.app.state.config.RAG_AZURE_OPENAI_BASE_URL
+                )
             ),
             (
                 request.app.state.config.RAG_OPENAI_API_KEY
                 if request.app.state.config.RAG_EMBEDDING_ENGINE == "openai"
-                else request.app.state.config.RAG_OLLAMA_API_KEY
+                else (
+                    request.app.state.config.RAG_OLLAMA_API_KEY
+                    if request.app.state.config.RAG_EMBEDDING_ENGINE == "ollama"
+                    else request.app.state.config.RAG_AZURE_OPENAI_API_KEY
+                )
             ),
             request.app.state.config.RAG_EMBEDDING_BATCH_SIZE,
+            (
+                request.app.state.config.RAG_AZURE_OPENAI_DEPLOYMENT
+                if request.app.state.config.RAG_EMBEDDING_ENGINE == "azure_openai"
+                else None
+            ),
+            (
+                request.app.state.config.RAG_AZURE_OPENAI_VERSION
+                if request.app.state.config.RAG_EMBEDDING_ENGINE == "azure_openai"
+                else None
+            ),
         )
 
         return {
@@ -327,6 +373,12 @@ async def update_embedding_config(
                 "url": request.app.state.config.RAG_OLLAMA_BASE_URL,
                 "key": request.app.state.config.RAG_OLLAMA_API_KEY,
             },
+            "azure_openai_config": {
+                "url": request.app.state.config.RAG_AZURE_OPENAI_BASE_URL,
+                "key": request.app.state.config.RAG_AZURE_OPENAI_API_KEY,
+                "deployment": request.app.state.config.RAG_AZURE_OPENAI_DEPLOYMENT,
+                "version": request.app.state.config.RAG_AZURE_OPENAI_VERSION,
+            },
         }
     except Exception as e:
         log.exception(f"Problem updating embedding model: {e}")
@@ -1043,14 +1095,32 @@ def save_docs_to_vector_db(
             (
                 request.app.state.config.RAG_OPENAI_API_BASE_URL
                 if request.app.state.config.RAG_EMBEDDING_ENGINE == "openai"
-                else request.app.state.config.RAG_OLLAMA_BASE_URL
+                else (
+                    request.app.state.config.RAG_OLLAMA_BASE_URL
+                    if request.app.state.config.RAG_EMBEDDING_ENGINE == "ollama"
+                    else request.app.state.config.RAG_AZURE_OPENAI_BASE_URL
+                )
             ),
             (
                 request.app.state.config.RAG_OPENAI_API_KEY
                 if request.app.state.config.RAG_EMBEDDING_ENGINE == "openai"
-                else request.app.state.config.RAG_OLLAMA_API_KEY
+                else (
+                    request.app.state.config.RAG_OLLAMA_API_KEY
+                    if request.app.state.config.RAG_EMBEDDING_ENGINE == "ollama"
+                    else request.app.state.config.RAG_AZURE_OPENAI_API_KEY
+                )
             ),
             request.app.state.config.RAG_EMBEDDING_BATCH_SIZE,
+            (
+                request.app.state.config.RAG_AZURE_OPENAI_DEPLOYMENT
+                if request.app.state.config.RAG_EMBEDDING_ENGINE == "azure_openai"
+                else None
+            ),
+            (
+                request.app.state.config.RAG_AZURE_OPENAI_VERSION
+                if request.app.state.config.RAG_EMBEDDING_ENGINE == "azure_openai"
+                else None
+            ),
         )
 
         embeddings = embedding_function(

+ 14 - 6
src/lib/apis/retrieval/index.ts

@@ -180,15 +180,23 @@ export const getEmbeddingConfig = async (token: string) => {
 };
 
 type OpenAIConfigForm = {
-	key: string;
-	url: string;
+        key: string;
+        url: string;
+};
+
+type AzureOpenAIConfigForm = {
+        key: string;
+        url: string;
+        deployment: string;
+        version: string;
 };
 
 type EmbeddingModelUpdateForm = {
-	openai_config?: OpenAIConfigForm;
-	embedding_engine: string;
-	embedding_model: string;
-	embedding_batch_size?: number;
+        openai_config?: OpenAIConfigForm;
+        azure_openai_config?: AzureOpenAIConfigForm;
+        embedding_engine: string;
+        embedding_model: string;
+        embedding_batch_size?: number;
 };
 
 export const updateEmbeddingConfig = async (token: string, payload: EmbeddingModelUpdateForm) => {

+ 89 - 37
src/lib/components/admin/Settings/Documents.svelte

@@ -43,8 +43,13 @@
 	let embeddingBatchSize = 1;
 	let rerankingModel = '';
 
-	let OpenAIUrl = '';
-	let OpenAIKey = '';
+        let OpenAIUrl = '';
+        let OpenAIKey = '';
+
+        let AzureOpenAIUrl = '';
+        let AzureOpenAIKey = '';
+        let AzureOpenAIDeployment = '';
+        let AzureOpenAIVersion = '';
 
 	let OllamaUrl = '';
 	let OllamaKey = '';
@@ -86,27 +91,40 @@
 			return;
 		}
 
-		if ((embeddingEngine === 'openai' && OpenAIKey === '') || OpenAIUrl === '') {
-			toast.error($i18n.t('OpenAI URL/Key required.'));
-			return;
-		}
+                if (embeddingEngine === 'openai' && (OpenAIKey === '' || OpenAIUrl === '')) {
+                        toast.error($i18n.t('OpenAI URL/Key required.'));
+                        return;
+                }
+                if (
+                        embeddingEngine === 'azure_openai' &&
+                        (AzureOpenAIKey === '' || AzureOpenAIUrl === '' || AzureOpenAIDeployment === '' || AzureOpenAIVersion === '')
+                ) {
+                        toast.error($i18n.t('OpenAI URL/Key required.'));
+                        return;
+                }
 
 		console.debug('Update embedding model attempt:', embeddingModel);
 
 		updateEmbeddingModelLoading = true;
-		const res = await updateEmbeddingConfig(localStorage.token, {
-			embedding_engine: embeddingEngine,
-			embedding_model: embeddingModel,
-			embedding_batch_size: embeddingBatchSize,
-			ollama_config: {
-				key: OllamaKey,
-				url: OllamaUrl
-			},
-			openai_config: {
-				key: OpenAIKey,
-				url: OpenAIUrl
-			}
-		}).catch(async (error) => {
+                const res = await updateEmbeddingConfig(localStorage.token, {
+                        embedding_engine: embeddingEngine,
+                        embedding_model: embeddingModel,
+                        embedding_batch_size: embeddingBatchSize,
+                        ollama_config: {
+                                key: OllamaKey,
+                                url: OllamaUrl
+                        },
+                        openai_config: {
+                                key: OpenAIKey,
+                                url: OpenAIUrl
+                        },
+                        azure_openai_config: {
+                                key: AzureOpenAIKey,
+                                url: AzureOpenAIUrl,
+                                deployment: AzureOpenAIDeployment,
+                                version: AzureOpenAIVersion
+                        }
+                }).catch(async (error) => {
 			toast.error(`${error}`);
 			await setEmbeddingConfig();
 			return null;
@@ -186,13 +204,18 @@
 			embeddingModel = embeddingConfig.embedding_model;
 			embeddingBatchSize = embeddingConfig.embedding_batch_size ?? 1;
 
-			OpenAIKey = embeddingConfig.openai_config.key;
-			OpenAIUrl = embeddingConfig.openai_config.url;
+                        OpenAIKey = embeddingConfig.openai_config.key;
+                        OpenAIUrl = embeddingConfig.openai_config.url;
 
-			OllamaKey = embeddingConfig.ollama_config.key;
-			OllamaUrl = embeddingConfig.ollama_config.url;
-		}
-	};
+                        OllamaKey = embeddingConfig.ollama_config.key;
+                        OllamaUrl = embeddingConfig.ollama_config.url;
+
+                        AzureOpenAIKey = embeddingConfig.azure_openai_config.key;
+                        AzureOpenAIUrl = embeddingConfig.azure_openai_config.url;
+                        AzureOpenAIDeployment = embeddingConfig.azure_openai_config.deployment;
+                        AzureOpenAIVersion = embeddingConfig.azure_openai_config.version;
+                }
+        };
 	onMount(async () => {
 		await setEmbeddingConfig();
 
@@ -457,23 +480,26 @@
 										bind:value={embeddingEngine}
 										placeholder="Select an embedding model engine"
 										on:change={(e) => {
-											if (e.target.value === 'ollama') {
-												embeddingModel = '';
-											} else if (e.target.value === 'openai') {
-												embeddingModel = 'text-embedding-3-small';
-											} else if (e.target.value === '') {
-												embeddingModel = 'sentence-transformers/all-MiniLM-L6-v2';
-											}
+                                                                        if (e.target.value === 'ollama') {
+                                                                               embeddingModel = '';
+                                                                       } else if (e.target.value === 'openai') {
+                                                                               embeddingModel = 'text-embedding-3-small';
+                                                                       } else if (e.target.value === 'azure_openai') {
+                                                                               embeddingModel = 'text-embedding-3-small';
+                                                                       } else if (e.target.value === '') {
+                                                                               embeddingModel = 'sentence-transformers/all-MiniLM-L6-v2';
+                                                                       }
 										}}
 									>
 										<option value="">{$i18n.t('Default (SentenceTransformers)')}</option>
 										<option value="ollama">{$i18n.t('Ollama')}</option>
-										<option value="openai">{$i18n.t('OpenAI')}</option>
+                                                                               <option value="openai">{$i18n.t('OpenAI')}</option>
+                                                                               <option value="azure_openai">Azure OpenAI</option>
 									</select>
 								</div>
 							</div>
 
-							{#if embeddingEngine === 'openai'}
+                                                        {#if embeddingEngine === 'openai'}
 								<div class="my-0.5 flex gap-2 pr-2">
 									<input
 										class="flex-1 w-full text-sm bg-transparent outline-hidden"
@@ -484,7 +510,7 @@
 
 									<SensitiveInput placeholder={$i18n.t('API Key')} bind:value={OpenAIKey} />
 								</div>
-							{:else if embeddingEngine === 'ollama'}
+                                                        {:else if embeddingEngine === 'ollama'}
 								<div class="my-0.5 flex gap-2 pr-2">
 									<input
 										class="flex-1 w-full text-sm bg-transparent outline-hidden"
@@ -499,7 +525,33 @@
 										required={false}
 									/>
 								</div>
-							{/if}
+                                                        {:else if embeddingEngine === 'azure_openai'}
+                                                                <div class="my-0.5 flex flex-col gap-2 pr-2 w-full">
+                                                                        <div class="flex gap-2">
+                                                                                <input
+                                                                                        class="flex-1 w-full text-sm bg-transparent outline-hidden"
+                                                                                        placeholder={$i18n.t('API Base URL')}
+                                                                                        bind:value={AzureOpenAIUrl}
+                                                                                        required
+                                                                                />
+                                                                                <SensitiveInput placeholder={$i18n.t('API Key')} bind:value={AzureOpenAIKey} />
+                                                                        </div>
+                                                                        <div class="flex gap-2">
+                                                                                <input
+                                                                                        class="flex-1 w-full text-sm bg-transparent outline-hidden"
+                                                                                        placeholder="Deployment"
+                                                                                        bind:value={AzureOpenAIDeployment}
+                                                                                        required
+                                                                                />
+                                                                                <input
+                                                                                        class="flex-1 w-full text-sm bg-transparent outline-hidden"
+                                                                                        placeholder="Version"
+                                                                                        bind:value={AzureOpenAIVersion}
+                                                                                        required
+                                                                                />
+                                                                        </div>
+                                                                </div>
+                                                        {/if}
 						</div>
 
 						<div class="  mb-2.5 flex flex-col w-full">
@@ -595,7 +647,7 @@
 							</div>
 						</div>
 
-						{#if embeddingEngine === 'ollama' || embeddingEngine === 'openai'}
+                                                {#if embeddingEngine === 'ollama' || embeddingEngine === 'openai' || embeddingEngine === 'azure_openai'}
 							<div class="  mb-2.5 flex w-full justify-between">
 								<div class=" self-center text-xs font-medium">
 									{$i18n.t('Embedding Batch Size')}