Browse Source

refac: PLEASE follow existing convention

Timothy Jaeryang Baek 4 months ago
parent
commit
e1e2c096e2

+ 4 - 9
backend/open_webui/config.py

@@ -2194,15 +2194,10 @@ RAG_AZURE_OPENAI_API_KEY = PersistentConfig(
     "rag.azure_openai.api_key",
     os.getenv("RAG_AZURE_OPENAI_API_KEY", ""),
 )
-RAG_AZURE_OPENAI_DEPLOYMENT = PersistentConfig(
-    "RAG_AZURE_OPENAI_DEPLOYMENT",
-    "rag.azure_openai.deployment",
-    os.getenv("RAG_AZURE_OPENAI_DEPLOYMENT", ""),
-)
-RAG_AZURE_OPENAI_VERSION = PersistentConfig(
-    "RAG_AZURE_OPENAI_VERSION",
-    "rag.azure_openai.version",
-    os.getenv("RAG_AZURE_OPENAI_VERSION", ""),
+RAG_AZURE_OPENAI_API_VERSION = PersistentConfig(
+    "RAG_AZURE_OPENAI_API_VERSION",
+    "rag.azure_openai.api_version",
+    os.getenv("RAG_AZURE_OPENAI_API_VERSION", ""),
 )
 
 RAG_OLLAMA_BASE_URL = PersistentConfig(

+ 4 - 11
backend/open_webui/main.py

@@ -209,8 +209,7 @@ from open_webui.config import (
     RAG_OPENAI_API_KEY,
     RAG_AZURE_OPENAI_BASE_URL,
     RAG_AZURE_OPENAI_API_KEY,
-    RAG_AZURE_OPENAI_DEPLOYMENT,
-    RAG_AZURE_OPENAI_VERSION,
+    RAG_AZURE_OPENAI_API_VERSION,
     RAG_OLLAMA_BASE_URL,
     RAG_OLLAMA_API_KEY,
     CHUNK_OVERLAP,
@@ -723,8 +722,7 @@ app.state.config.RAG_OPENAI_API_KEY = RAG_OPENAI_API_KEY
 
 app.state.config.RAG_AZURE_OPENAI_BASE_URL = RAG_AZURE_OPENAI_BASE_URL
 app.state.config.RAG_AZURE_OPENAI_API_KEY = RAG_AZURE_OPENAI_API_KEY
-app.state.config.RAG_AZURE_OPENAI_DEPLOYMENT = RAG_AZURE_OPENAI_DEPLOYMENT
-app.state.config.RAG_AZURE_OPENAI_VERSION = RAG_AZURE_OPENAI_VERSION
+app.state.config.RAG_AZURE_OPENAI_API_VERSION = RAG_AZURE_OPENAI_API_VERSION
 
 app.state.config.RAG_OLLAMA_BASE_URL = RAG_OLLAMA_BASE_URL
 app.state.config.RAG_OLLAMA_API_KEY = RAG_OLLAMA_API_KEY
@@ -836,13 +834,8 @@ app.state.EMBEDDING_FUNCTION = get_embedding_function(
         )
     ),
     app.state.config.RAG_EMBEDDING_BATCH_SIZE,
-    (
-        app.state.config.RAG_AZURE_OPENAI_DEPLOYMENT
-        if app.state.config.RAG_EMBEDDING_ENGINE == "azure_openai"
-        else None
-    ),
-    (
-        app.state.config.RAG_AZURE_OPENAI_VERSION
+    azure_api_version=(
+        app.state.config.RAG_AZURE_OPENAI_API_VERSION
         if app.state.config.RAG_EMBEDDING_ENGINE == "azure_openai"
         else None
     ),

+ 11 - 15
backend/open_webui/retrieval/utils.py

@@ -401,8 +401,7 @@ def get_embedding_function(
     url,
     key,
     embedding_batch_size,
-    deployment=None,
-    version=None,
+    azure_api_version=None,
 ):
     if embedding_engine == "":
         return lambda query, prefix=None, user=None: embedding_function.encode(
@@ -417,8 +416,7 @@ def get_embedding_function(
             url=url,
             key=key,
             user=user,
-            deployment=deployment,
-            version=version,
+            azure_api_version=azure_api_version,
         )
 
         def generate_multiple(query, prefix, user, func):
@@ -703,24 +701,23 @@ def generate_openai_batch_embeddings(
 
 
 def generate_azure_openai_batch_embeddings(
-    deployment: str,
+    model: str,
     texts: list[str],
     url: str,
     key: str = "",
-    model: str = "",
     version: str = "",
     prefix: str = None,
     user: UserModel = None,
 ) -> Optional[list[list[float]]]:
     try:
         log.debug(
-            f"generate_azure_openai_batch_embeddings:deployment {deployment} batch size: {len(texts)}"
+            f"generate_azure_openai_batch_embeddings:deployment {model} batch size: {len(texts)}"
         )
-        json_data = {"input": texts, "model": model}
+        json_data = {"input": texts}
         if isinstance(RAG_EMBEDDING_PREFIX_FIELD_NAME, str) and isinstance(prefix, str):
             json_data[RAG_EMBEDDING_PREFIX_FIELD_NAME] = prefix
 
-        url = f"{url}/openai/deployments/{deployment}/embeddings?api-version={version}"
+        url = f"{url}/openai/deployments/{model}/embeddings?api-version={version}"
 
         for _ in range(5):
             r = requests.post(
@@ -855,27 +852,26 @@ def generate_embeddings(
             )
         return embeddings[0] if isinstance(text, str) else embeddings
     elif engine == "azure_openai":
-        deployment = kwargs.get("deployment", "")
-        version = kwargs.get("version", "")
+        azure_api_version = kwargs.get("azure_api_version", "")
         if isinstance(text, list):
             embeddings = generate_azure_openai_batch_embeddings(
-                deployment,
+                model,
                 text,
                 url,
                 key,
                 model,
-                version,
+                azure_api_version,
                 prefix,
                 user,
             )
         else:
             embeddings = generate_azure_openai_batch_embeddings(
-                deployment,
+                model,
                 [text],
                 url,
                 key,
                 model,
-                version,
+                azure_api_version,
                 prefix,
                 user,
             )

+ 12 - 24
backend/open_webui/routers/retrieval.py

@@ -242,8 +242,7 @@ async def get_embedding_config(request: Request, user=Depends(get_admin_user)):
         "azure_openai_config": {
             "url": request.app.state.config.RAG_AZURE_OPENAI_BASE_URL,
             "key": request.app.state.config.RAG_AZURE_OPENAI_API_KEY,
-            "deployment": request.app.state.config.RAG_AZURE_OPENAI_DEPLOYMENT,
-            "version": request.app.state.config.RAG_AZURE_OPENAI_VERSION,
+            "version": request.app.state.config.RAG_AZURE_OPENAI_API_VERSION,
         },
     }
 
@@ -261,7 +260,6 @@ class OllamaConfigForm(BaseModel):
 class AzureOpenAIConfigForm(BaseModel):
     url: str
     key: str
-    deployment: str
     version: str
 
 
@@ -285,7 +283,11 @@ async def update_embedding_config(
         request.app.state.config.RAG_EMBEDDING_ENGINE = form_data.embedding_engine
         request.app.state.config.RAG_EMBEDDING_MODEL = form_data.embedding_model
 
-        if request.app.state.config.RAG_EMBEDDING_ENGINE in ["ollama", "openai", "azure_openai"]:
+        if request.app.state.config.RAG_EMBEDDING_ENGINE in [
+            "ollama",
+            "openai",
+            "azure_openai",
+        ]:
             if form_data.openai_config is not None:
                 request.app.state.config.RAG_OPENAI_API_BASE_URL = (
                     form_data.openai_config.url
@@ -309,10 +311,7 @@ async def update_embedding_config(
                 request.app.state.config.RAG_AZURE_OPENAI_API_KEY = (
                     form_data.azure_openai_config.key
                 )
-                request.app.state.config.RAG_AZURE_OPENAI_DEPLOYMENT = (
-                    form_data.azure_openai_config.deployment
-                )
-                request.app.state.config.RAG_AZURE_OPENAI_VERSION = (
+                request.app.state.config.RAG_AZURE_OPENAI_API_VERSION = (
                     form_data.azure_openai_config.version
                 )
 
@@ -348,13 +347,8 @@ async def update_embedding_config(
                 )
             ),
             request.app.state.config.RAG_EMBEDDING_BATCH_SIZE,
-            (
-                request.app.state.config.RAG_AZURE_OPENAI_DEPLOYMENT
-                if request.app.state.config.RAG_EMBEDDING_ENGINE == "azure_openai"
-                else None
-            ),
-            (
-                request.app.state.config.RAG_AZURE_OPENAI_VERSION
+            azure_api_version=(
+                request.app.state.config.RAG_AZURE_OPENAI_API_VERSION
                 if request.app.state.config.RAG_EMBEDDING_ENGINE == "azure_openai"
                 else None
             ),
@@ -376,8 +370,7 @@ async def update_embedding_config(
             "azure_openai_config": {
                 "url": request.app.state.config.RAG_AZURE_OPENAI_BASE_URL,
                 "key": request.app.state.config.RAG_AZURE_OPENAI_API_KEY,
-                "deployment": request.app.state.config.RAG_AZURE_OPENAI_DEPLOYMENT,
-                "version": request.app.state.config.RAG_AZURE_OPENAI_VERSION,
+                "version": request.app.state.config.RAG_AZURE_OPENAI_API_VERSION,
             },
         }
     except Exception as e:
@@ -1197,13 +1190,8 @@ def save_docs_to_vector_db(
                 )
             ),
             request.app.state.config.RAG_EMBEDDING_BATCH_SIZE,
-            (
-                request.app.state.config.RAG_AZURE_OPENAI_DEPLOYMENT
-                if request.app.state.config.RAG_EMBEDDING_ENGINE == "azure_openai"
-                else None
-            ),
-            (
-                request.app.state.config.RAG_AZURE_OPENAI_VERSION
+            azure_api_version=(
+                request.app.state.config.RAG_AZURE_OPENAI_API_VERSION
                 if request.app.state.config.RAG_EMBEDDING_ENGINE == "azure_openai"
                 else None
             ),

+ 10 - 11
src/lib/apis/retrieval/index.ts

@@ -180,23 +180,22 @@ export const getEmbeddingConfig = async (token: string) => {
 };
 
 type OpenAIConfigForm = {
-        key: string;
-        url: string;
+	key: string;
+	url: string;
 };
 
 type AzureOpenAIConfigForm = {
-        key: string;
-        url: string;
-        deployment: string;
-        version: string;
+	key: string;
+	url: string;
+	version: string;
 };
 
 type EmbeddingModelUpdateForm = {
-        openai_config?: OpenAIConfigForm;
-        azure_openai_config?: AzureOpenAIConfigForm;
-        embedding_engine: string;
-        embedding_model: string;
-        embedding_batch_size?: number;
+	openai_config?: OpenAIConfigForm;
+	azure_openai_config?: AzureOpenAIConfigForm;
+	embedding_engine: string;
+	embedding_model: string;
+	embedding_batch_size?: number;
 };
 
 export const updateEmbeddingConfig = async (token: string, payload: EmbeddingModelUpdateForm) => {

+ 78 - 87
src/lib/components/admin/Settings/Documents.svelte

@@ -43,13 +43,12 @@
 	let embeddingBatchSize = 1;
 	let rerankingModel = '';
 
-        let OpenAIUrl = '';
-        let OpenAIKey = '';
+	let OpenAIUrl = '';
+	let OpenAIKey = '';
 
-        let AzureOpenAIUrl = '';
-        let AzureOpenAIKey = '';
-        let AzureOpenAIDeployment = '';
-        let AzureOpenAIVersion = '';
+	let AzureOpenAIUrl = '';
+	let AzureOpenAIKey = '';
+	let AzureOpenAIVersion = '';
 
 	let OllamaUrl = '';
 	let OllamaKey = '';
@@ -91,40 +90,39 @@
 			return;
 		}
 
-                if (embeddingEngine === 'openai' && (OpenAIKey === '' || OpenAIUrl === '')) {
-                        toast.error($i18n.t('OpenAI URL/Key required.'));
-                        return;
-                }
-                if (
-                        embeddingEngine === 'azure_openai' &&
-                        (AzureOpenAIKey === '' || AzureOpenAIUrl === '' || AzureOpenAIDeployment === '' || AzureOpenAIVersion === '')
-                ) {
-                        toast.error($i18n.t('OpenAI URL/Key required.'));
-                        return;
-                }
+		if (embeddingEngine === 'openai' && (OpenAIKey === '' || OpenAIUrl === '')) {
+			toast.error($i18n.t('OpenAI URL/Key required.'));
+			return;
+		}
+		if (
+			embeddingEngine === 'azure_openai' &&
+			(AzureOpenAIKey === '' || AzureOpenAIUrl === '' || AzureOpenAIVersion === '')
+		) {
+			toast.error($i18n.t('OpenAI URL/Key required.'));
+			return;
+		}
 
 		console.debug('Update embedding model attempt:', embeddingModel);
 
 		updateEmbeddingModelLoading = true;
-                const res = await updateEmbeddingConfig(localStorage.token, {
-                        embedding_engine: embeddingEngine,
-                        embedding_model: embeddingModel,
-                        embedding_batch_size: embeddingBatchSize,
-                        ollama_config: {
-                                key: OllamaKey,
-                                url: OllamaUrl
-                        },
-                        openai_config: {
-                                key: OpenAIKey,
-                                url: OpenAIUrl
-                        },
-                        azure_openai_config: {
-                                key: AzureOpenAIKey,
-                                url: AzureOpenAIUrl,
-                                deployment: AzureOpenAIDeployment,
-                                version: AzureOpenAIVersion
-                        }
-                }).catch(async (error) => {
+		const res = await updateEmbeddingConfig(localStorage.token, {
+			embedding_engine: embeddingEngine,
+			embedding_model: embeddingModel,
+			embedding_batch_size: embeddingBatchSize,
+			ollama_config: {
+				key: OllamaKey,
+				url: OllamaUrl
+			},
+			openai_config: {
+				key: OpenAIKey,
+				url: OpenAIUrl
+			},
+			azure_openai_config: {
+				key: AzureOpenAIKey,
+				url: AzureOpenAIUrl,
+				version: AzureOpenAIVersion
+			}
+		}).catch(async (error) => {
 			toast.error(`${error}`);
 			await setEmbeddingConfig();
 			return null;
@@ -218,18 +216,17 @@
 			embeddingModel = embeddingConfig.embedding_model;
 			embeddingBatchSize = embeddingConfig.embedding_batch_size ?? 1;
 
-                        OpenAIKey = embeddingConfig.openai_config.key;
-                        OpenAIUrl = embeddingConfig.openai_config.url;
+			OpenAIKey = embeddingConfig.openai_config.key;
+			OpenAIUrl = embeddingConfig.openai_config.url;
 
-                        OllamaKey = embeddingConfig.ollama_config.key;
-                        OllamaUrl = embeddingConfig.ollama_config.url;
+			OllamaKey = embeddingConfig.ollama_config.key;
+			OllamaUrl = embeddingConfig.ollama_config.url;
 
-                        AzureOpenAIKey = embeddingConfig.azure_openai_config.key;
-                        AzureOpenAIUrl = embeddingConfig.azure_openai_config.url;
-                        AzureOpenAIDeployment = embeddingConfig.azure_openai_config.deployment;
-                        AzureOpenAIVersion = embeddingConfig.azure_openai_config.version;
-                }
-        };
+			AzureOpenAIKey = embeddingConfig.azure_openai_config.key;
+			AzureOpenAIUrl = embeddingConfig.azure_openai_config.url;
+			AzureOpenAIVersion = embeddingConfig.azure_openai_config.version;
+		}
+	};
 	onMount(async () => {
 		await setEmbeddingConfig();
 
@@ -626,26 +623,26 @@
 										bind:value={embeddingEngine}
 										placeholder="Select an embedding model engine"
 										on:change={(e) => {
-                                                                        if (e.target.value === 'ollama') {
-                                                                               embeddingModel = '';
-                                                                       } else if (e.target.value === 'openai') {
-                                                                               embeddingModel = 'text-embedding-3-small';
-                                                                       } else if (e.target.value === 'azure_openai') {
-                                                                               embeddingModel = 'text-embedding-3-small';
-                                                                       } else if (e.target.value === '') {
-                                                                               embeddingModel = 'sentence-transformers/all-MiniLM-L6-v2';
-                                                                       }
+											if (e.target.value === 'ollama') {
+												embeddingModel = '';
+											} else if (e.target.value === 'openai') {
+												embeddingModel = 'text-embedding-3-small';
+											} else if (e.target.value === 'azure_openai') {
+												embeddingModel = 'text-embedding-3-small';
+											} else if (e.target.value === '') {
+												embeddingModel = 'sentence-transformers/all-MiniLM-L6-v2';
+											}
 										}}
 									>
 										<option value="">{$i18n.t('Default (SentenceTransformers)')}</option>
 										<option value="ollama">{$i18n.t('Ollama')}</option>
-                                                                               <option value="openai">{$i18n.t('OpenAI')}</option>
-                                                                               <option value="azure_openai">Azure OpenAI</option>
+										<option value="openai">{$i18n.t('OpenAI')}</option>
+										<option value="azure_openai">Azure OpenAI</option>
 									</select>
 								</div>
 							</div>
 
-                                                        {#if embeddingEngine === 'openai'}
+							{#if embeddingEngine === 'openai'}
 								<div class="my-0.5 flex gap-2 pr-2">
 									<input
 										class="flex-1 w-full text-sm bg-transparent outline-hidden"
@@ -656,7 +653,7 @@
 
 									<SensitiveInput placeholder={$i18n.t('API Key')} bind:value={OpenAIKey} />
 								</div>
-                                                        {:else if embeddingEngine === 'ollama'}
+							{:else if embeddingEngine === 'ollama'}
 								<div class="my-0.5 flex gap-2 pr-2">
 									<input
 										class="flex-1 w-full text-sm bg-transparent outline-hidden"
@@ -671,33 +668,27 @@
 										required={false}
 									/>
 								</div>
-                                                        {:else if embeddingEngine === 'azure_openai'}
-                                                                <div class="my-0.5 flex flex-col gap-2 pr-2 w-full">
-                                                                        <div class="flex gap-2">
-                                                                                <input
-                                                                                        class="flex-1 w-full text-sm bg-transparent outline-hidden"
-                                                                                        placeholder={$i18n.t('API Base URL')}
-                                                                                        bind:value={AzureOpenAIUrl}
-                                                                                        required
-                                                                                />
-                                                                                <SensitiveInput placeholder={$i18n.t('API Key')} bind:value={AzureOpenAIKey} />
-                                                                        </div>
-                                                                        <div class="flex gap-2">
-                                                                                <input
-                                                                                        class="flex-1 w-full text-sm bg-transparent outline-hidden"
-                                                                                        placeholder="Deployment"
-                                                                                        bind:value={AzureOpenAIDeployment}
-                                                                                        required
-                                                                                />
-                                                                                <input
-                                                                                        class="flex-1 w-full text-sm bg-transparent outline-hidden"
-                                                                                        placeholder="Version"
-                                                                                        bind:value={AzureOpenAIVersion}
-                                                                                        required
-                                                                                />
-                                                                        </div>
-                                                                </div>
-                                                        {/if}
+							{:else if embeddingEngine === 'azure_openai'}
+								<div class="my-0.5 flex flex-col gap-2 pr-2 w-full">
+									<div class="flex gap-2">
+										<input
+											class="flex-1 w-full text-sm bg-transparent outline-hidden"
+											placeholder={$i18n.t('API Base URL')}
+											bind:value={AzureOpenAIUrl}
+											required
+										/>
+										<SensitiveInput placeholder={$i18n.t('API Key')} bind:value={AzureOpenAIKey} />
+									</div>
+									<div class="flex gap-2">
+										<input
+											class="flex-1 w-full text-sm bg-transparent outline-hidden"
+											placeholder="Version"
+											bind:value={AzureOpenAIVersion}
+											required
+										/>
+									</div>
+								</div>
+							{/if}
 						</div>
 
 						<div class="  mb-2.5 flex flex-col w-full">
@@ -793,7 +784,7 @@
 							</div>
 						</div>
 
-                                                {#if embeddingEngine === 'ollama' || embeddingEngine === 'openai' || embeddingEngine === 'azure_openai'}
+						{#if embeddingEngine === 'ollama' || embeddingEngine === 'openai' || embeddingEngine === 'azure_openai'}
 							<div class="  mb-2.5 flex w-full justify-between">
 								<div class=" self-center text-xs font-medium">
 									{$i18n.t('Embedding Batch Size')}