Browse Source

feat: model filter backend

Timothy J. Baek 1 year ago
parent
commit
b550e23bf6

+ 13 - 2
backend/apps/ollama/main.py

@@ -29,6 +29,10 @@ app.add_middleware(
     allow_headers=["*"],
 )
 
+
+app.state.MODEL_FILTER_ENABLED = False
+app.state.MODEL_LIST = []
+
 app.state.OLLAMA_BASE_URLS = OLLAMA_BASE_URLS
 app.state.MODELS = {}
 
@@ -129,9 +133,16 @@ async def get_all_models():
 async def get_ollama_tags(
     url_idx: Optional[int] = None, user=Depends(get_current_user)
 ):
-
     if url_idx == None:
-        return await get_all_models()
+        models = await get_all_models()
+        if app.state.MODEL_FILTER_ENABLED:
+            if user.role == "user":
+                models["models"] = filter(
+                    lambda model: model["name"] in app.state.MODEL_LIST,
+                    models["models"],
+                )
+                return models
+        return models
     else:
         url = app.state.OLLAMA_BASE_URLS[url_idx]
         try:

+ 13 - 3
backend/apps/openai/main.py

@@ -34,6 +34,9 @@ app.add_middleware(
     allow_headers=["*"],
 )
 
+app.state.MODEL_FILTER_ENABLED = False
+app.state.MODEL_LIST = []
+
 app.state.OPENAI_API_BASE_URLS = OPENAI_API_BASE_URLS
 app.state.OPENAI_API_KEYS = OPENAI_API_KEYS
 
@@ -186,12 +189,19 @@ async def get_all_models():
     return models
 
 
-# , user=Depends(get_current_user)
 @app.get("/models")
 @app.get("/models/{url_idx}")
-async def get_models(url_idx: Optional[int] = None):
+async def get_models(url_idx: Optional[int] = None, user=Depends(get_current_user)):
     if url_idx == None:
-        return await get_all_models()
+        models = await get_all_models()
+        if app.state.MODEL_FILTER_ENABLED:
+            if user.role == "user":
+                models["data"] = filter(
+                    lambda model: model["id"] in app.state.MODEL_LIST,
+                    models["data"],
+                )
+                return models
+        return models
     else:
         url = app.state.OPENAI_API_BASE_URLS[url_idx]
         try:

+ 34 - 0
backend/main.py

@@ -23,7 +23,11 @@ from apps.images.main import app as images_app
 from apps.rag.main import app as rag_app
 from apps.web.main import app as webui_app
 
+from pydantic import BaseModel
+from typing import List
 
+
+from utils.utils import get_admin_user
 from apps.rag.utils import query_doc, query_collection, rag_template
 
 from config import WEBUI_NAME, ENV, VERSION, CHANGELOG, FRONTEND_BUILD_DIR
@@ -43,6 +47,9 @@ class SPAStaticFiles(StaticFiles):
 
 app = FastAPI(docs_url="/docs" if ENV == "dev" else None, redoc_url=None)
 
+app.state.MODEL_FILTER_ENABLED = False
+app.state.MODEL_LIST = []
+
 origins = ["*"]
 
 app.add_middleware(
@@ -211,6 +218,33 @@ async def get_app_config():
     }
 
 
+@app.get("/api/config/model/filter")
+async def get_model_filter_config(user=Depends(get_admin_user)):
+    return {"enabled": app.state.MODEL_FILTER_ENABLED, "models": app.state.MODEL_LIST}
+
+
+class ModelFilterConfigForm(BaseModel):
+    enabled: bool
+    models: List[str]
+
+
+@app.post("/api/config/model/filter")
+async def get_model_filter_config(
+    form_data: ModelFilterConfigForm, user=Depends(get_admin_user)
+):
+
+    app.state.MODEL_FILTER_ENABLED = form_data.enabled
+    app.state.MODEL_LIST = form_data.models
+
+    ollama_app.state.MODEL_FILTER_ENABLED = app.state.MODEL_FILTER_ENABLED
+    ollama_app.state.MODEL_LIST = app.state.MODEL_LIST
+
+    openai_app.state.MODEL_FILTER_ENABLED = app.state.MODEL_FILTER_ENABLED
+    openai_app.state.MODEL_LIST = app.state.MODEL_LIST
+
+    return {"enabled": app.state.MODEL_FILTER_ENABLED, "models": app.state.MODEL_LIST}
+
+
 @app.get("/api/version")
 async def get_app_config():
 

+ 1 - 1
src/lib/components/chat/MessageInput.svelte

@@ -19,7 +19,7 @@
 
 	export let suggestionPrompts = [];
 	export let autoScroll = true;
-	let chatTextAreaElement:HTMLTextAreaElement
+	let chatTextAreaElement: HTMLTextAreaElement;
 	let filesInputElement;
 
 	let promptsElement;