|
@@ -36,7 +36,6 @@ from fastapi import (
|
|
|
applications,
|
|
|
BackgroundTasks,
|
|
|
)
|
|
|
-
|
|
|
from fastapi.openapi.docs import get_swagger_ui_html
|
|
|
|
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
@@ -49,6 +48,7 @@ from starlette.exceptions import HTTPException as StarletteHTTPException
|
|
|
from starlette.middleware.base import BaseHTTPMiddleware
|
|
|
from starlette.middleware.sessions import SessionMiddleware
|
|
|
from starlette.responses import Response, StreamingResponse
|
|
|
+from starlette.datastructures import Headers
|
|
|
|
|
|
|
|
|
from open_webui.utils import logger
|
|
@@ -116,6 +116,8 @@ from open_webui.config import (
|
|
|
OPENAI_API_CONFIGS,
|
|
|
# Direct Connections
|
|
|
ENABLE_DIRECT_CONNECTIONS,
|
|
|
+ # Model list
|
|
|
+ ENABLE_MODEL_LIST_CACHE,
|
|
|
# Thread pool size for FastAPI/AnyIO
|
|
|
THREAD_POOL_SIZE,
|
|
|
# Tool Server Configs
|
|
@@ -534,6 +536,27 @@ async def lifespan(app: FastAPI):
|
|
|
|
|
|
asyncio.create_task(periodic_usage_pool_cleanup())
|
|
|
|
|
|
+ if app.state.config.ENABLE_MODEL_LIST_CACHE:
|
|
|
+ get_all_models(
|
|
|
+ Request(
|
|
|
+ # Creating a mock request object to pass to get_all_models
|
|
|
+ {
|
|
|
+ "type": "http",
|
|
|
+ "asgi.version": "3.0",
|
|
|
+ "asgi.spec_version": "2.0",
|
|
|
+ "method": "GET",
|
|
|
+ "path": "/internal",
|
|
|
+ "query_string": b"",
|
|
|
+ "headers": Headers({}).raw,
|
|
|
+ "client": ("127.0.0.1", 12345),
|
|
|
+ "server": ("127.0.0.1", 80),
|
|
|
+ "scheme": "http",
|
|
|
+ "app": app,
|
|
|
+ }
|
|
|
+ ),
|
|
|
+ None,
|
|
|
+ )
|
|
|
+
|
|
|
yield
|
|
|
|
|
|
if hasattr(app.state, "redis_task_command_listener"):
|
|
@@ -616,6 +639,14 @@ app.state.TOOL_SERVERS = []
|
|
|
|
|
|
app.state.config.ENABLE_DIRECT_CONNECTIONS = ENABLE_DIRECT_CONNECTIONS
|
|
|
|
|
|
+########################################
|
|
|
+#
|
|
|
+# MODEL LIST
|
|
|
+#
|
|
|
+########################################
|
|
|
+
|
|
|
+app.state.config.ENABLE_MODEL_LIST_CACHE = ENABLE_MODEL_LIST_CACHE
|
|
|
+
|
|
|
########################################
|
|
|
#
|
|
|
# WEBUI
|
|
@@ -1191,7 +1222,9 @@ if audit_level != AuditLevel.NONE:
|
|
|
|
|
|
|
|
|
@app.get("/api/models")
|
|
|
-async def get_models(request: Request, user=Depends(get_verified_user)):
|
|
|
+async def get_models(
|
|
|
+ request: Request, refresh: bool = False, user=Depends(get_verified_user)
|
|
|
+):
|
|
|
def get_filtered_models(models, user):
|
|
|
filtered_models = []
|
|
|
for model in models:
|
|
@@ -1215,7 +1248,12 @@ async def get_models(request: Request, user=Depends(get_verified_user)):
|
|
|
|
|
|
return filtered_models
|
|
|
|
|
|
- all_models = await get_all_models(request, user=user)
|
|
|
+ if request.app.state.MODELS and (
|
|
|
+ request.app.state.config.ENABLE_MODEL_LIST_CACHE and not refresh
|
|
|
+ ):
|
|
|
+ all_models = list(request.app.state.MODELS.values())
|
|
|
+ else:
|
|
|
+ all_models = await get_all_models(request, user=user)
|
|
|
|
|
|
models = []
|
|
|
for model in all_models:
|