Procházet zdrojové kódy

Add read/write access control for files from knowledge

tarmst před 4 měsíci
rodič
revize
1ad80490de
1 změnil soubory, kde provedl 104 přidání a 7 odebrání
  1. 104 7
      backend/open_webui/routers/files.py

+ 104 - 7
backend/open_webui/routers/files.py

@@ -15,6 +15,7 @@ from open_webui.models.files import (
     FileModelResponse,
     Files,
 )
+from open_webui.routers.knowledge import get_knowledge, get_knowledge_list
 from open_webui.routers.retrieval import ProcessFileForm, process_file
 from open_webui.routers.audio import transcribe
 from open_webui.storage.provider import Storage
@@ -27,6 +28,43 @@ log.setLevel(SRC_LOG_LEVELS["MODELS"])
 
 router = APIRouter()
 
+############################
+# Check if the current user has access to a file through any knowledge bases the user may be in.
+############################
+async def check_user_has_access_to_file_via_any_knowledge_base(file_id: Optional[str], access_type: str, user=Depends(get_verified_user)) -> bool:
+    file = Files.get_file_by_id(file_id)
+    log.debug(f"Checking if user has {access_type} access to file")
+
+    if not file:
+        raise HTTPException(
+            status_code=status.HTTP_404_NOT_FOUND,
+            detail=ERROR_MESSAGES.NOT_FOUND,
+        )
+    
+    has_access = False
+    knowledge_base_id = file.meta.get("collection_name") if file.meta else None
+    log.debug(f"Knowledge base associated with file: {knowledge_base_id}")
+    if knowledge_base_id:
+        if access_type == "read":
+            user_access = await get_knowledge(user=user) # get_knowledge checks for read access
+        elif access_type == "write":
+            user_access = await get_knowledge_list(user=user) # get_knowledge_list checks for write access
+        else:
+            user_access = list()
+        
+        for knowledge_base in user_access:
+            if knowledge_base.id == knowledge_base_id:
+                log.debug(f"User knowledge base with {access_type} access {knowledge_base.id} == File knowledge base {knowledge_base_id}")
+                has_access = True
+                break
+
+    
+    log.debug(f"Does user have {access_type} access to file: {has_access}")
+
+    return has_access
+    
+
+
 ############################
 # Upload File
 ############################
@@ -160,7 +198,15 @@ async def delete_all_files(user=Depends(get_admin_user)):
 async def get_file_by_id(id: str, user=Depends(get_verified_user)):
     file = Files.get_file_by_id(id)
 
-    if file and (file.user_id == user.id or user.role == "admin"):
+    if not file:
+        raise HTTPException(
+            status_code=status.HTTP_404_NOT_FOUND,
+            detail=ERROR_MESSAGES.NOT_FOUND,
+        )
+    
+    has_read_access: bool = await check_user_has_access_to_file_via_any_knowledge_base(id, "read", user)
+
+    if file.user_id == user.id or user.role == "admin" or has_read_access:
         return file
     else:
         raise HTTPException(
@@ -178,7 +224,15 @@ async def get_file_by_id(id: str, user=Depends(get_verified_user)):
 async def get_file_data_content_by_id(id: str, user=Depends(get_verified_user)):
     file = Files.get_file_by_id(id)
 
-    if file and (file.user_id == user.id or user.role == "admin"):
+    if not file:
+        raise HTTPException(
+            status_code=status.HTTP_404_NOT_FOUND,
+            detail=ERROR_MESSAGES.NOT_FOUND,
+        )
+    
+    has_read_access: bool = await check_user_has_access_to_file_via_any_knowledge_base(id, "read", user)
+
+    if file.user_id == user.id or user.role == "admin" or has_read_access:
         return {"content": file.data.get("content", "")}
     else:
         raise HTTPException(
@@ -202,7 +256,15 @@ async def update_file_data_content_by_id(
 ):
     file = Files.get_file_by_id(id)
 
-    if file and (file.user_id == user.id or user.role == "admin"):
+    if not file:
+        raise HTTPException(
+            status_code=status.HTTP_404_NOT_FOUND,
+            detail=ERROR_MESSAGES.NOT_FOUND,
+        )
+    
+    has_write_access: bool = await check_user_has_access_to_file_via_any_knowledge_base(id, "write", user)
+
+    if file.user_id == user.id or user.role == "admin" or has_write_access:
         try:
             process_file(
                 request,
@@ -230,7 +292,16 @@ async def update_file_data_content_by_id(
 @router.get("/{id}/content")
 async def get_file_content_by_id(id: str, user=Depends(get_verified_user)):
     file = Files.get_file_by_id(id)
-    if file and (file.user_id == user.id or user.role == "admin"):
+
+    if not file:
+        raise HTTPException(
+            status_code=status.HTTP_404_NOT_FOUND,
+            detail=ERROR_MESSAGES.NOT_FOUND,
+        )
+    
+    has_read_access: bool = await check_user_has_access_to_file_via_any_knowledge_base(id, "read", user)
+
+    if file.user_id == user.id or user.role == "admin" or has_read_access:
         try:
             file_path = Storage.get_file(file.path)
             file_path = Path(file_path)
@@ -282,7 +353,16 @@ async def get_file_content_by_id(id: str, user=Depends(get_verified_user)):
 @router.get("/{id}/content/html")
 async def get_html_file_content_by_id(id: str, user=Depends(get_verified_user)):
     file = Files.get_file_by_id(id)
-    if file and (file.user_id == user.id or user.role == "admin"):
+
+    if not file:
+        raise HTTPException(
+            status_code=status.HTTP_404_NOT_FOUND,
+            detail=ERROR_MESSAGES.NOT_FOUND,
+        )
+    
+    has_read_access: bool = await check_user_has_access_to_file_via_any_knowledge_base(id, "read", user)
+
+    if file.user_id == user.id or user.role == "admin" or has_read_access:
         try:
             file_path = Storage.get_file(file.path)
             file_path = Path(file_path)
@@ -314,7 +394,15 @@ async def get_html_file_content_by_id(id: str, user=Depends(get_verified_user)):
 async def get_file_content_by_id(id: str, user=Depends(get_verified_user)):
     file = Files.get_file_by_id(id)
 
-    if file and (file.user_id == user.id or user.role == "admin"):
+    if not file:
+        raise HTTPException(
+            status_code=status.HTTP_404_NOT_FOUND,
+            detail=ERROR_MESSAGES.NOT_FOUND,
+        )
+    
+    has_read_access: bool = await check_user_has_access_to_file_via_any_knowledge_base(id, "read", user)
+
+    if file.user_id == user.id or user.role == "admin" or has_read_access:
         file_path = file.path
 
         # Handle Unicode filenames
@@ -365,7 +453,16 @@ async def get_file_content_by_id(id: str, user=Depends(get_verified_user)):
 @router.delete("/{id}")
 async def delete_file_by_id(id: str, user=Depends(get_verified_user)):
     file = Files.get_file_by_id(id)
-    if file and (file.user_id == user.id or user.role == "admin"):
+
+    if not file:
+        raise HTTPException(
+            status_code=status.HTTP_404_NOT_FOUND,
+            detail=ERROR_MESSAGES.NOT_FOUND,
+        )
+    
+    has_write_access: bool = await check_user_has_access_to_file_via_any_knowledge_base(id, "write", user)
+
+    if file.user_id == user.id or user.role == "admin" or has_write_access:
         # We should add Chroma cleanup here
 
         result = Files.delete_file_by_id(id)