浏览代码

enh: ENABLE_MODEL_LIST_CACHE

Timothy Jaeryang Baek 3 月之前
父节点
当前提交
1a52585769

+ 12 - 0
backend/open_webui/config.py

@@ -923,6 +923,18 @@ except Exception:
     pass
 OPENAI_API_BASE_URL = "https://api.openai.com/v1"
 
+
+####################################
+# MODEL_LIST
+####################################
+
# Server-side caching of the aggregated model list. Off by default; enabled
# when the ENABLE_MODEL_LIST_CACHE environment variable is the string "true"
# (case-insensitive). Persisted under the "models.cache" config path.
_model_list_cache_env = os.environ.get("ENABLE_MODEL_LIST_CACHE", "False")
ENABLE_MODEL_LIST_CACHE = PersistentConfig(
    "ENABLE_MODEL_LIST_CACHE",
    "models.cache",
    _model_list_cache_env.lower() == "true",
)
+
+
 ####################################
 # TOOL_SERVERS
 ####################################

+ 41 - 3
backend/open_webui/main.py

@@ -36,7 +36,6 @@ from fastapi import (
     applications,
     BackgroundTasks,
 )
-
 from fastapi.openapi.docs import get_swagger_ui_html
 
 from fastapi.middleware.cors import CORSMiddleware
@@ -49,6 +48,7 @@ from starlette.exceptions import HTTPException as StarletteHTTPException
 from starlette.middleware.base import BaseHTTPMiddleware
 from starlette.middleware.sessions import SessionMiddleware
 from starlette.responses import Response, StreamingResponse
+from starlette.datastructures import Headers
 
 
 from open_webui.utils import logger
@@ -116,6 +116,8 @@ from open_webui.config import (
     OPENAI_API_CONFIGS,
     # Direct Connections
     ENABLE_DIRECT_CONNECTIONS,
+    # Model list
+    ENABLE_MODEL_LIST_CACHE,
     # Thread pool size for FastAPI/AnyIO
     THREAD_POOL_SIZE,
     # Tool Server Configs
@@ -534,6 +536,27 @@ async def lifespan(app: FastAPI):
 
     asyncio.create_task(periodic_usage_pool_cleanup())
 
+    if app.state.config.ENABLE_MODEL_LIST_CACHE:
+        get_all_models(
+            Request(
+                # Creating a mock request object to pass to get_all_models
+                {
+                    "type": "http",
+                    "asgi.version": "3.0",
+                    "asgi.spec_version": "2.0",
+                    "method": "GET",
+                    "path": "/internal",
+                    "query_string": b"",
+                    "headers": Headers({}).raw,
+                    "client": ("127.0.0.1", 12345),
+                    "server": ("127.0.0.1", 80),
+                    "scheme": "http",
+                    "app": app,
+                }
+            ),
+            None,
+        )
+
     yield
 
     if hasattr(app.state, "redis_task_command_listener"):
@@ -616,6 +639,14 @@ app.state.TOOL_SERVERS = []
 
 app.state.config.ENABLE_DIRECT_CONNECTIONS = ENABLE_DIRECT_CONNECTIONS
 
+########################################
+#
+# MODEL LIST
+#
+########################################
+
+app.state.config.ENABLE_MODEL_LIST_CACHE = ENABLE_MODEL_LIST_CACHE
+
 ########################################
 #
 # WEBUI
@@ -1191,7 +1222,9 @@ if audit_level != AuditLevel.NONE:
 
 
 @app.get("/api/models")
-async def get_models(request: Request, user=Depends(get_verified_user)):
+async def get_models(
+    request: Request, refresh: bool = False, user=Depends(get_verified_user)
+):
     def get_filtered_models(models, user):
         filtered_models = []
         for model in models:
@@ -1215,7 +1248,12 @@ async def get_models(request: Request, user=Depends(get_verified_user)):
 
         return filtered_models
 
-    all_models = await get_all_models(request, user=user)
+    if request.app.state.MODELS and (
+        request.app.state.config.ENABLE_MODEL_LIST_CACHE and not refresh
+    ):
+        all_models = list(request.app.state.MODELS.values())
+    else:
+        all_models = await get_all_models(request, user=user)
 
     models = []
     for model in all_models:

+ 12 - 7
backend/open_webui/routers/configs.py

@@ -39,32 +39,37 @@ async def export_config(user=Depends(get_admin_user)):
 
 
 ############################
-# Direct Connections Config
+# Connections Config
 ############################
 
 
-class DirectConnectionsConfigForm(BaseModel):
class ConnectionsConfigForm(BaseModel):
    """Request/response schema for the admin /configs/connections endpoints."""

    # Whether users may configure direct connections.
    ENABLE_DIRECT_CONNECTIONS: bool
    # Whether the server caches the aggregated model list (served from
    # app.state.MODELS instead of re-fetching on every /api/models call).
    ENABLE_MODEL_LIST_CACHE: bool
 
 
-@router.get("/direct_connections", response_model=DirectConnectionsConfigForm)
-async def get_direct_connections_config(request: Request, user=Depends(get_admin_user)):
@router.get("/connections", response_model=ConnectionsConfigForm)
async def get_connections_config(request: Request, user=Depends(get_admin_user)):
    """Return the current connections configuration (admin only)."""
    config = request.app.state.config
    return {
        "ENABLE_DIRECT_CONNECTIONS": config.ENABLE_DIRECT_CONNECTIONS,
        "ENABLE_MODEL_LIST_CACHE": config.ENABLE_MODEL_LIST_CACHE,
    }
 
 
-@router.post("/direct_connections", response_model=DirectConnectionsConfigForm)
-async def set_direct_connections_config(
@router.post("/connections", response_model=ConnectionsConfigForm)
async def set_connections_config(
    request: Request,
    form_data: ConnectionsConfigForm,
    user=Depends(get_admin_user),
):
    """Update the connections configuration and echo the stored values (admin only)."""
    config = request.app.state.config
    config.ENABLE_DIRECT_CONNECTIONS = form_data.ENABLE_DIRECT_CONNECTIONS
    config.ENABLE_MODEL_LIST_CACHE = form_data.ENABLE_MODEL_LIST_CACHE

    # Read back from app state so the response reflects what was persisted.
    return {
        "ENABLE_DIRECT_CONNECTIONS": config.ENABLE_DIRECT_CONNECTIONS,
        "ENABLE_MODEL_LIST_CACHE": config.ENABLE_MODEL_LIST_CACHE,
    }
 
 

+ 4 - 4
src/lib/apis/configs/index.ts

@@ -58,10 +58,10 @@ export const exportConfig = async (token: string) => {
 	return res;
 };
 
-export const getDirectConnectionsConfig = async (token: string) => {
+export const getConnectionsConfig = async (token: string) => {
 	let error = null;
 
-	const res = await fetch(`${WEBUI_API_BASE_URL}/configs/direct_connections`, {
+	const res = await fetch(`${WEBUI_API_BASE_URL}/configs/connections`, {
 		method: 'GET',
 		headers: {
 			'Content-Type': 'application/json',
@@ -85,10 +85,10 @@ export const getDirectConnectionsConfig = async (token: string) => {
 	return res;
 };
 
-export const setDirectConnectionsConfig = async (token: string, config: object) => {
+export const setConnectionsConfig = async (token: string, config: object) => {
 	let error = null;
 
-	const res = await fetch(`${WEBUI_API_BASE_URL}/configs/direct_connections`, {
+	const res = await fetch(`${WEBUI_API_BASE_URL}/configs/connections`, {
 		method: 'POST',
 		headers: {
 			'Content-Type': 'application/json',

+ 36 - 13
src/lib/components/admin/Settings/Connections.svelte

@@ -7,7 +7,7 @@
 	import { getOllamaConfig, updateOllamaConfig } from '$lib/apis/ollama';
 	import { getOpenAIConfig, updateOpenAIConfig, getOpenAIModels } from '$lib/apis/openai';
 	import { getModels as _getModels } from '$lib/apis';
-	import { getDirectConnectionsConfig, setDirectConnectionsConfig } from '$lib/apis/configs';
+	import { getConnectionsConfig, setConnectionsConfig } from '$lib/apis/configs';
 
 	import { config, models, settings, user } from '$lib/stores';
 
@@ -43,7 +43,7 @@
 	let ENABLE_OPENAI_API: null | boolean = null;
 	let ENABLE_OLLAMA_API: null | boolean = null;
 
-	let directConnectionsConfig = null;
+	let connectionsConfig = null;
 
 	let pipelineUrls = {};
 	let showAddOpenAIConnectionModal = false;
@@ -106,15 +106,13 @@
 		}
 	};
 
-	const updateDirectConnectionsHandler = async () => {
-		const res = await setDirectConnectionsConfig(localStorage.token, directConnectionsConfig).catch(
-			(error) => {
-				toast.error(`${error}`);
-			}
-		);
+	const updateConnectionsHandler = async () => {
+		const res = await setConnectionsConfig(localStorage.token, connectionsConfig).catch((error) => {
+			toast.error(`${error}`);
+		});
 
 		if (res) {
-			toast.success($i18n.t('Direct Connections settings updated'));
+			toast.success($i18n.t('Connections settings updated'));
 			await models.set(await getModels());
 		}
 	};
@@ -150,7 +148,7 @@
 					openaiConfig = await getOpenAIConfig(localStorage.token);
 				})(),
 				(async () => {
-					directConnectionsConfig = await getDirectConnectionsConfig(localStorage.token);
+					connectionsConfig = await getConnectionsConfig(localStorage.token);
 				})()
 			]);
 
@@ -217,7 +215,7 @@
 
 <form class="flex flex-col h-full justify-between text-sm" on:submit|preventDefault={submitHandler}>
 	<div class=" overflow-y-scroll scrollbar-hidden h-full">
-		{#if ENABLE_OPENAI_API !== null && ENABLE_OLLAMA_API !== null && directConnectionsConfig !== null}
+		{#if ENABLE_OPENAI_API !== null && ENABLE_OLLAMA_API !== null && connectionsConfig !== null}
 			<div class="mb-3.5">
 				<div class=" mb-2.5 text-base font-medium">{$i18n.t('General')}</div>
 
@@ -368,9 +366,9 @@
 						<div class="flex items-center">
 							<div class="">
 								<Switch
-									bind:state={directConnectionsConfig.ENABLE_DIRECT_CONNECTIONS}
+									bind:state={connectionsConfig.ENABLE_DIRECT_CONNECTIONS}
 									on:change={async () => {
-										updateDirectConnectionsHandler();
+										updateConnectionsHandler();
 									}}
 								/>
 							</div>
@@ -383,6 +381,31 @@
 						)}
 					</div>
 				</div>
+
+				<hr class=" border-gray-100 dark:border-gray-850 my-2" />
+
+				<div class="my-2">
+					<div class="flex justify-between items-center text-sm">
+						<div class=" text-xs font-medium">{$i18n.t('Cache Model List')}</div>
+
+						<div class="flex items-center">
+							<div class="">
+								<Switch
+									bind:state={connectionsConfig.ENABLE_MODEL_LIST_CACHE}
+									on:change={async () => {
+										updateConnectionsHandler();
+									}}
+								/>
+							</div>
+						</div>
+					</div>
+
+					<div class="mt-1 text-xs text-gray-400 dark:text-gray-500">
+						{$i18n.t(
+							'Model List Cache allows for faster access to model information by caching it locally.'
+						)}
+					</div>
+				</div>
 			</div>
 		{:else}
 			<div class="flex h-full justify-center">