Timothy Jaeryang Baek 3 месяцев назад
Родитель
Сommit
1a52585769

+ 12 - 0
backend/open_webui/config.py

@@ -923,6 +923,18 @@ except Exception:
     pass
     pass
 OPENAI_API_BASE_URL = "https://api.openai.com/v1"
 OPENAI_API_BASE_URL = "https://api.openai.com/v1"
 
 
+
+####################################
+# MODEL_LIST
+####################################
+
+ENABLE_MODEL_LIST_CACHE = PersistentConfig(
+    "ENABLE_MODEL_LIST_CACHE",
+    "models.cache",
+    os.environ.get("ENABLE_MODEL_LIST_CACHE", "False").lower() == "true",
+)
+
+
 ####################################
 ####################################
 # TOOL_SERVERS
 # TOOL_SERVERS
 ####################################
 ####################################

+ 41 - 3
backend/open_webui/main.py

@@ -36,7 +36,6 @@ from fastapi import (
     applications,
     applications,
     BackgroundTasks,
     BackgroundTasks,
 )
 )
-
 from fastapi.openapi.docs import get_swagger_ui_html
 from fastapi.openapi.docs import get_swagger_ui_html
 
 
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.middleware.cors import CORSMiddleware
@@ -49,6 +48,7 @@ from starlette.exceptions import HTTPException as StarletteHTTPException
 from starlette.middleware.base import BaseHTTPMiddleware
 from starlette.middleware.base import BaseHTTPMiddleware
 from starlette.middleware.sessions import SessionMiddleware
 from starlette.middleware.sessions import SessionMiddleware
 from starlette.responses import Response, StreamingResponse
 from starlette.responses import Response, StreamingResponse
+from starlette.datastructures import Headers
 
 
 
 
 from open_webui.utils import logger
 from open_webui.utils import logger
@@ -116,6 +116,8 @@ from open_webui.config import (
     OPENAI_API_CONFIGS,
     OPENAI_API_CONFIGS,
     # Direct Connections
     # Direct Connections
     ENABLE_DIRECT_CONNECTIONS,
     ENABLE_DIRECT_CONNECTIONS,
+    # Model list
+    ENABLE_MODEL_LIST_CACHE,
     # Thread pool size for FastAPI/AnyIO
     # Thread pool size for FastAPI/AnyIO
     THREAD_POOL_SIZE,
     THREAD_POOL_SIZE,
     # Tool Server Configs
     # Tool Server Configs
@@ -534,6 +536,27 @@ async def lifespan(app: FastAPI):
 
 
     asyncio.create_task(periodic_usage_pool_cleanup())
     asyncio.create_task(periodic_usage_pool_cleanup())
 
 
+    if app.state.config.ENABLE_MODEL_LIST_CACHE:
+        get_all_models(
+            Request(
+                # Creating a mock request object to pass to get_all_models
+                {
+                    "type": "http",
+                    "asgi.version": "3.0",
+                    "asgi.spec_version": "2.0",
+                    "method": "GET",
+                    "path": "/internal",
+                    "query_string": b"",
+                    "headers": Headers({}).raw,
+                    "client": ("127.0.0.1", 12345),
+                    "server": ("127.0.0.1", 80),
+                    "scheme": "http",
+                    "app": app,
+                }
+            ),
+            None,
+        )
+
     yield
     yield
 
 
     if hasattr(app.state, "redis_task_command_listener"):
     if hasattr(app.state, "redis_task_command_listener"):
@@ -616,6 +639,14 @@ app.state.TOOL_SERVERS = []
 
 
 app.state.config.ENABLE_DIRECT_CONNECTIONS = ENABLE_DIRECT_CONNECTIONS
 app.state.config.ENABLE_DIRECT_CONNECTIONS = ENABLE_DIRECT_CONNECTIONS
 
 
+########################################
+#
+# MODEL LIST
+#
+########################################
+
+app.state.config.ENABLE_MODEL_LIST_CACHE = ENABLE_MODEL_LIST_CACHE
+
 ########################################
 ########################################
 #
 #
 # WEBUI
 # WEBUI
@@ -1191,7 +1222,9 @@ if audit_level != AuditLevel.NONE:
 
 
 
 
 @app.get("/api/models")
 @app.get("/api/models")
-async def get_models(request: Request, user=Depends(get_verified_user)):
+async def get_models(
+    request: Request, refresh: bool = False, user=Depends(get_verified_user)
+):
     def get_filtered_models(models, user):
     def get_filtered_models(models, user):
         filtered_models = []
         filtered_models = []
         for model in models:
         for model in models:
@@ -1215,7 +1248,12 @@ async def get_models(request: Request, user=Depends(get_verified_user)):
 
 
         return filtered_models
         return filtered_models
 
 
-    all_models = await get_all_models(request, user=user)
+    if request.app.state.MODELS and (
+        request.app.state.config.ENABLE_MODEL_LIST_CACHE and not refresh
+    ):
+        all_models = list(request.app.state.MODELS.values())
+    else:
+        all_models = await get_all_models(request, user=user)
 
 
     models = []
     models = []
     for model in all_models:
     for model in all_models:

+ 12 - 7
backend/open_webui/routers/configs.py

@@ -39,32 +39,37 @@ async def export_config(user=Depends(get_admin_user)):
 
 
 
 
 ############################
 ############################
-# Direct Connections Config
+# Connections Config
 ############################
 ############################
 
 
 
 
-class DirectConnectionsConfigForm(BaseModel):
+class ConnectionsConfigForm(BaseModel):
     ENABLE_DIRECT_CONNECTIONS: bool
     ENABLE_DIRECT_CONNECTIONS: bool
+    ENABLE_MODEL_LIST_CACHE: bool
 
 
 
 
-@router.get("/direct_connections", response_model=DirectConnectionsConfigForm)
-async def get_direct_connections_config(request: Request, user=Depends(get_admin_user)):
+@router.get("/connections", response_model=ConnectionsConfigForm)
+async def get_connections_config(request: Request, user=Depends(get_admin_user)):
     return {
     return {
         "ENABLE_DIRECT_CONNECTIONS": request.app.state.config.ENABLE_DIRECT_CONNECTIONS,
         "ENABLE_DIRECT_CONNECTIONS": request.app.state.config.ENABLE_DIRECT_CONNECTIONS,
+        "ENABLE_MODEL_LIST_CACHE": request.app.state.config.ENABLE_MODEL_LIST_CACHE,
     }
     }
 
 
 
 
-@router.post("/direct_connections", response_model=DirectConnectionsConfigForm)
-async def set_direct_connections_config(
+@router.post("/connections", response_model=ConnectionsConfigForm)
+async def set_connections_config(
     request: Request,
     request: Request,
-    form_data: DirectConnectionsConfigForm,
+    form_data: ConnectionsConfigForm,
     user=Depends(get_admin_user),
     user=Depends(get_admin_user),
 ):
 ):
     request.app.state.config.ENABLE_DIRECT_CONNECTIONS = (
     request.app.state.config.ENABLE_DIRECT_CONNECTIONS = (
         form_data.ENABLE_DIRECT_CONNECTIONS
         form_data.ENABLE_DIRECT_CONNECTIONS
     )
     )
+    request.app.state.config.ENABLE_MODEL_LIST_CACHE = form_data.ENABLE_MODEL_LIST_CACHE
+
     return {
     return {
         "ENABLE_DIRECT_CONNECTIONS": request.app.state.config.ENABLE_DIRECT_CONNECTIONS,
         "ENABLE_DIRECT_CONNECTIONS": request.app.state.config.ENABLE_DIRECT_CONNECTIONS,
+        "ENABLE_MODEL_LIST_CACHE": request.app.state.config.ENABLE_MODEL_LIST_CACHE,
     }
     }
 
 
 
 

+ 4 - 4
src/lib/apis/configs/index.ts

@@ -58,10 +58,10 @@ export const exportConfig = async (token: string) => {
 	return res;
 	return res;
 };
 };
 
 
-export const getDirectConnectionsConfig = async (token: string) => {
+export const getConnectionsConfig = async (token: string) => {
 	let error = null;
 	let error = null;
 
 
-	const res = await fetch(`${WEBUI_API_BASE_URL}/configs/direct_connections`, {
+	const res = await fetch(`${WEBUI_API_BASE_URL}/configs/connections`, {
 		method: 'GET',
 		method: 'GET',
 		headers: {
 		headers: {
 			'Content-Type': 'application/json',
 			'Content-Type': 'application/json',
@@ -85,10 +85,10 @@ export const getDirectConnectionsConfig = async (token: string) => {
 	return res;
 	return res;
 };
 };
 
 
-export const setDirectConnectionsConfig = async (token: string, config: object) => {
+export const setConnectionsConfig = async (token: string, config: object) => {
 	let error = null;
 	let error = null;
 
 
-	const res = await fetch(`${WEBUI_API_BASE_URL}/configs/direct_connections`, {
+	const res = await fetch(`${WEBUI_API_BASE_URL}/configs/connections`, {
 		method: 'POST',
 		method: 'POST',
 		headers: {
 		headers: {
 			'Content-Type': 'application/json',
 			'Content-Type': 'application/json',

+ 36 - 13
src/lib/components/admin/Settings/Connections.svelte

@@ -7,7 +7,7 @@
 	import { getOllamaConfig, updateOllamaConfig } from '$lib/apis/ollama';
 	import { getOllamaConfig, updateOllamaConfig } from '$lib/apis/ollama';
 	import { getOpenAIConfig, updateOpenAIConfig, getOpenAIModels } from '$lib/apis/openai';
 	import { getOpenAIConfig, updateOpenAIConfig, getOpenAIModels } from '$lib/apis/openai';
 	import { getModels as _getModels } from '$lib/apis';
 	import { getModels as _getModels } from '$lib/apis';
-	import { getDirectConnectionsConfig, setDirectConnectionsConfig } from '$lib/apis/configs';
+	import { getConnectionsConfig, setConnectionsConfig } from '$lib/apis/configs';
 
 
 	import { config, models, settings, user } from '$lib/stores';
 	import { config, models, settings, user } from '$lib/stores';
 
 
@@ -43,7 +43,7 @@
 	let ENABLE_OPENAI_API: null | boolean = null;
 	let ENABLE_OPENAI_API: null | boolean = null;
 	let ENABLE_OLLAMA_API: null | boolean = null;
 	let ENABLE_OLLAMA_API: null | boolean = null;
 
 
-	let directConnectionsConfig = null;
+	let connectionsConfig = null;
 
 
 	let pipelineUrls = {};
 	let pipelineUrls = {};
 	let showAddOpenAIConnectionModal = false;
 	let showAddOpenAIConnectionModal = false;
@@ -106,15 +106,13 @@
 		}
 		}
 	};
 	};
 
 
-	const updateDirectConnectionsHandler = async () => {
-		const res = await setDirectConnectionsConfig(localStorage.token, directConnectionsConfig).catch(
-			(error) => {
-				toast.error(`${error}`);
-			}
-		);
+	const updateConnectionsHandler = async () => {
+		const res = await setConnectionsConfig(localStorage.token, connectionsConfig).catch((error) => {
+			toast.error(`${error}`);
+		});
 
 
 		if (res) {
 		if (res) {
-			toast.success($i18n.t('Direct Connections settings updated'));
+			toast.success($i18n.t('Connections settings updated'));
 			await models.set(await getModels());
 			await models.set(await getModels());
 		}
 		}
 	};
 	};
@@ -150,7 +148,7 @@
 					openaiConfig = await getOpenAIConfig(localStorage.token);
 					openaiConfig = await getOpenAIConfig(localStorage.token);
 				})(),
 				})(),
 				(async () => {
 				(async () => {
-					directConnectionsConfig = await getDirectConnectionsConfig(localStorage.token);
+					connectionsConfig = await getConnectionsConfig(localStorage.token);
 				})()
 				})()
 			]);
 			]);
 
 
@@ -217,7 +215,7 @@
 
 
 <form class="flex flex-col h-full justify-between text-sm" on:submit|preventDefault={submitHandler}>
 <form class="flex flex-col h-full justify-between text-sm" on:submit|preventDefault={submitHandler}>
 	<div class=" overflow-y-scroll scrollbar-hidden h-full">
 	<div class=" overflow-y-scroll scrollbar-hidden h-full">
-		{#if ENABLE_OPENAI_API !== null && ENABLE_OLLAMA_API !== null && directConnectionsConfig !== null}
+		{#if ENABLE_OPENAI_API !== null && ENABLE_OLLAMA_API !== null && connectionsConfig !== null}
 			<div class="mb-3.5">
 			<div class="mb-3.5">
 				<div class=" mb-2.5 text-base font-medium">{$i18n.t('General')}</div>
 				<div class=" mb-2.5 text-base font-medium">{$i18n.t('General')}</div>
 
 
@@ -368,9 +366,9 @@
 						<div class="flex items-center">
 						<div class="flex items-center">
 							<div class="">
 							<div class="">
 								<Switch
 								<Switch
-									bind:state={directConnectionsConfig.ENABLE_DIRECT_CONNECTIONS}
+									bind:state={connectionsConfig.ENABLE_DIRECT_CONNECTIONS}
 									on:change={async () => {
 									on:change={async () => {
-										updateDirectConnectionsHandler();
+										updateConnectionsHandler();
 									}}
 									}}
 								/>
 								/>
 							</div>
 							</div>
@@ -383,6 +381,31 @@
 						)}
 						)}
 					</div>
 					</div>
 				</div>
 				</div>
+
+				<hr class=" border-gray-100 dark:border-gray-850 my-2" />
+
+				<div class="my-2">
+					<div class="flex justify-between items-center text-sm">
+						<div class=" text-xs font-medium">{$i18n.t('Cache Model List')}</div>
+
+						<div class="flex items-center">
+							<div class="">
+								<Switch
+									bind:state={connectionsConfig.ENABLE_MODEL_LIST_CACHE}
+									on:change={async () => {
+										updateConnectionsHandler();
+									}}
+								/>
+							</div>
+						</div>
+					</div>
+
+					<div class="mt-1 text-xs text-gray-400 dark:text-gray-500">
+						{$i18n.t(
+							'Model List Cache allows for faster access to model information by caching it locally.'
+						)}
+					</div>
+				</div>
 			</div>
 			</div>
 		{:else}
 		{:else}
 			<div class="flex h-full justify-center">
 			<div class="flex h-full justify-center">