Ver código fonte

refac: MODEL_LIST_CACHE_TTL -> MODELS_CACHE_TTL

Timothy Jaeryang Baek 3 meses atrás
pai
commit
2b88f66762

+ 17 - 8
backend/open_webui/env.py

@@ -399,18 +399,27 @@ ENABLE_COMPRESSION_MIDDLEWARE = (
     os.environ.get("ENABLE_COMPRESSION_MIDDLEWARE", "True").lower() == "true"
 )
 
-ENABLE_WEBSOCKET_SUPPORT = (
-    os.environ.get("ENABLE_WEBSOCKET_SUPPORT", "True").lower() == "true"
-)
+####################################
+# MODELS
+####################################
 
-MODEL_LIST_CACHE_TTL = os.environ.get("MODEL_LIST_CACHE_TTL", "1")
-if MODEL_LIST_CACHE_TTL == "":
-    MODEL_LIST_CACHE_TTL = None
+MODELS_CACHE_TTL = os.environ.get("MODELS_CACHE_TTL", "1")
+if MODELS_CACHE_TTL == "":
+    MODELS_CACHE_TTL = None
 else:
     try:
-        MODEL_LIST_CACHE_TTL = int(MODEL_LIST_CACHE_TTL)
+        MODELS_CACHE_TTL = int(MODELS_CACHE_TTL)
     except Exception:
-        MODEL_LIST_CACHE_TTL = 1
+        MODELS_CACHE_TTL = 1
+
+
+####################################
+# WEBSOCKET SUPPORT
+####################################
+
+ENABLE_WEBSOCKET_SUPPORT = (
+    os.environ.get("ENABLE_WEBSOCKET_SUPPORT", "True").lower() == "true"
+)
 
 
 WEBSOCKET_MANAGER = os.environ.get("WEBSOCKET_MANAGER", "")

+ 2 - 2
backend/open_webui/routers/ollama.py

@@ -59,7 +59,7 @@ from open_webui.config import (
 from open_webui.env import (
     ENV,
     SRC_LOG_LEVELS,
-    MODEL_LIST_CACHE_TTL,
+    MODELS_CACHE_TTL,
     AIOHTTP_CLIENT_SESSION_SSL,
     AIOHTTP_CLIENT_TIMEOUT,
     AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST,
@@ -331,7 +331,7 @@ def merge_ollama_models_lists(model_lists):
     return list(merged_models.values())
 
 
-@cached(ttl=MODEL_LIST_CACHE_TTL)
+@cached(ttl=MODELS_CACHE_TTL)
 async def get_all_models(request: Request, user: UserModel = None):
     log.info("get_all_models()")
     if request.app.state.config.ENABLE_OLLAMA_API:

+ 2 - 2
backend/open_webui/routers/openai.py

@@ -21,7 +21,7 @@ from open_webui.config import (
     CACHE_DIR,
 )
 from open_webui.env import (
-    MODEL_LIST_CACHE_TTL,
+    MODELS_CACHE_TTL,
     AIOHTTP_CLIENT_SESSION_SSL,
     AIOHTTP_CLIENT_TIMEOUT,
     AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST,
@@ -387,7 +387,7 @@ async def get_filtered_models(models, user):
     return filtered_models
 
 
-@cached(ttl=MODEL_LIST_CACHE_TTL)
+@cached(ttl=MODELS_CACHE_TTL)
 async def get_all_models(request: Request, user: UserModel) -> dict[str, list]:
     log.info("get_all_models()")