Browse Source

fix: model filter issue

Timothy J. Baek 1 year ago
parent
commit
3aa6b0fea9
4 changed files with 21 additions and 18 deletions
  1. 4 4
      backend/apps/litellm/main.py
  2. 8 6
      backend/apps/ollama/main.py
  3. 5 4
      backend/apps/openai/main.py
  4. 4 4
      backend/main.py

+ 4 - 4
backend/apps/litellm/main.py

@@ -75,6 +75,10 @@ with open(LITELLM_CONFIG_DIR, "r") as file:
     litellm_config = yaml.safe_load(file)
 
 
+app.state.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER
+app.state.MODEL_FILTER_LIST = MODEL_FILTER_LIST
+
+
 app.state.ENABLE = ENABLE_LITELLM
 app.state.CONFIG = litellm_config
 
@@ -151,10 +155,6 @@ async def shutdown_litellm_background():
         background_process = None
 
 
-app.state.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER
-app.state.MODEL_FILTER_LIST = MODEL_FILTER_LIST
-
-
 @app.get("/")
 async def get_status():
     return {"status": True}

+ 8 - 6
backend/apps/ollama/main.py

@@ -64,8 +64,8 @@ app.add_middleware(
 
 app.state.config = AppConfig()
 
-app.state.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER
-app.state.MODEL_FILTER_LIST = MODEL_FILTER_LIST
+app.state.config.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER
+app.state.config.MODEL_FILTER_LIST = MODEL_FILTER_LIST
 
 app.state.config.OLLAMA_BASE_URLS = OLLAMA_BASE_URLS
 app.state.MODELS = {}
@@ -178,11 +178,12 @@ async def get_ollama_tags(
     if url_idx == None:
         models = await get_all_models()
 
-        if app.state.ENABLE_MODEL_FILTER:
+        if app.state.config.ENABLE_MODEL_FILTER:
             if user.role == "user":
                 models["models"] = list(
                     filter(
-                        lambda model: model["name"] in app.state.MODEL_FILTER_LIST,
+                        lambda model: model["name"]
+                        in app.state.config.MODEL_FILTER_LIST,
                         models["models"],
                     )
                 )
@@ -1046,11 +1047,12 @@ async def get_openai_models(
     if url_idx == None:
         models = await get_all_models()
 
-        if app.state.ENABLE_MODEL_FILTER:
+        if app.state.config.ENABLE_MODEL_FILTER:
             if user.role == "user":
                 models["models"] = list(
                     filter(
-                        lambda model: model["name"] in app.state.MODEL_FILTER_LIST,
+                        lambda model: model["name"]
+                        in app.state.config.MODEL_FILTER_LIST,
                         models["models"],
                     )
                 )

+ 5 - 4
backend/apps/openai/main.py

@@ -47,10 +47,11 @@ app.add_middleware(
     allow_headers=["*"],
 )
 
+
 app.state.config = AppConfig()
 
-app.state.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER
-app.state.MODEL_FILTER_LIST = MODEL_FILTER_LIST
+app.state.config.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER
+app.state.config.MODEL_FILTER_LIST = MODEL_FILTER_LIST
 
 
 app.state.config.ENABLE_OPENAI_API = ENABLE_OPENAI_API
@@ -259,11 +260,11 @@ async def get_all_models():
 async def get_models(url_idx: Optional[int] = None, user=Depends(get_current_user)):
     if url_idx == None:
         models = await get_all_models()
-        if app.state.ENABLE_MODEL_FILTER:
+        if app.state.config.ENABLE_MODEL_FILTER:
             if user.role == "user":
                 models["data"] = list(
                     filter(
-                        lambda model: model["id"] in app.state.MODEL_FILTER_LIST,
+                        lambda model: model["id"] in app.state.config.MODEL_FILTER_LIST,
                         models["data"],
                     )
                 )

+ 4 - 4
backend/main.py

@@ -292,11 +292,11 @@ async def update_model_filter_config(
     app.state.config.ENABLE_MODEL_FILTER = form_data.enabled
     app.state.config.MODEL_FILTER_LIST = form_data.models
 
-    ollama_app.state.ENABLE_MODEL_FILTER = app.state.config.ENABLE_MODEL_FILTER
-    ollama_app.state.MODEL_FILTER_LIST = app.state.config.MODEL_FILTER_LIST
+    ollama_app.state.config.ENABLE_MODEL_FILTER = app.state.config.ENABLE_MODEL_FILTER
+    ollama_app.state.config.MODEL_FILTER_LIST = app.state.config.MODEL_FILTER_LIST
 
-    openai_app.state.ENABLE_MODEL_FILTER = app.state.config.ENABLE_MODEL_FILTER
-    openai_app.state.MODEL_FILTER_LIST = app.state.config.MODEL_FILTER_LIST
+    openai_app.state.config.ENABLE_MODEL_FILTER = app.state.config.ENABLE_MODEL_FILTER
+    openai_app.state.config.MODEL_FILTER_LIST = app.state.config.MODEL_FILTER_LIST
 
     litellm_app.state.ENABLE_MODEL_FILTER = app.state.config.ENABLE_MODEL_FILTER
     litellm_app.state.MODEL_FILTER_LIST = app.state.config.MODEL_FILTER_LIST