Bladeren bron

feat: `RAG_ALLOWED_FILE_EXTENSIONS`

Timothy Jaeryang Baek 4 maanden geleden
bovenliggende
commit
a6624a4b16
3 gewijzigde bestanden met toevoegingen van 18 en 0 verwijderingen
  1. 6 0
      backend/open_webui/config.py
  2. 2 0
      backend/open_webui/main.py
  3. 10 0
      backend/open_webui/routers/files.py

+ 6 - 0
backend/open_webui/config.py

@@ -1951,6 +1951,12 @@ RAG_FILE_MAX_SIZE = PersistentConfig(
     ),
 )
 
+RAG_ALLOWED_FILE_EXTENSIONS = PersistentConfig(
+    "RAG_ALLOWED_FILE_EXTENSIONS",
+    "rag.file.allowed_extensions",
+    os.environ.get("RAG_ALLOWED_FILE_EXTENSIONS", "").split(","),
+)
+
 RAG_EMBEDDING_ENGINE = PersistentConfig(
     "RAG_EMBEDDING_ENGINE",
     "rag.embedding_engine",

+ 2 - 0
backend/open_webui/main.py

@@ -197,6 +197,7 @@ from open_webui.config import (
     RAG_EMBEDDING_ENGINE,
     RAG_EMBEDDING_BATCH_SIZE,
     RAG_RELEVANCE_THRESHOLD,
+    RAG_ALLOWED_FILE_EXTENSIONS,
     RAG_FILE_MAX_COUNT,
     RAG_FILE_MAX_SIZE,
     RAG_OPENAI_API_BASE_URL,
@@ -638,6 +639,7 @@ app.state.FUNCTIONS = {}
 app.state.config.TOP_K = RAG_TOP_K
 app.state.config.TOP_K_RERANKER = RAG_TOP_K_RERANKER
 app.state.config.RELEVANCE_THRESHOLD = RAG_RELEVANCE_THRESHOLD
+app.state.config.ALLOWED_FILE_EXTENSIONS = RAG_ALLOWED_FILE_EXTENSIONS
 app.state.config.FILE_MAX_SIZE = RAG_FILE_MAX_SIZE
 app.state.config.FILE_MAX_COUNT = RAG_FILE_MAX_COUNT
 

+ 10 - 0
backend/open_webui/routers/files.py

@@ -95,6 +95,16 @@ def upload_file(
         unsanitized_filename = file.filename
         filename = os.path.basename(unsanitized_filename)
 
+        file_extension = os.path.splitext(filename)[1]
+        if request.app.state.config.ALLOWED_FILE_EXTENSIONS:
+            if file_extension not in request.app.state.config.ALLOWED_FILE_EXTENSIONS:
+                raise HTTPException(
+                    status_code=status.HTTP_400_BAD_REQUEST,
+                    detail=ERROR_MESSAGES.DEFAULT(
+                        f"File type {file_extension} is not allowed"
+                    ),
+                )
+
         # replace filename with uuid
         id = str(uuid.uuid4())
         name = filename