Timothy Jaeryang Baek 3 ماه پیش
والد
کامیت
3b9d86de0b
2 فایلهای تغییر یافته به همراه 91 افزوده شده و 71 حذف شده
  1. 83 66
      backend/open_webui/retrieval/utils.py
  2. 8 5
      backend/open_webui/utils/middleware.py

+ 83 - 66
backend/open_webui/retrieval/utils.py

@@ -18,9 +18,11 @@ from open_webui.retrieval.vector.factory import VECTOR_DB_CLIENT
 
 from open_webui.models.users import UserModel
 from open_webui.models.files import Files
+from open_webui.models.knowledge import Knowledges
 from open_webui.models.notes import Notes
 
 from open_webui.retrieval.vector.main import GetResult
+from open_webui.utils.access_control import has_access
 
 
 from open_webui.env import (
@@ -443,9 +445,9 @@ def get_embedding_function(
         raise ValueError(f"Unknown embedding engine: {embedding_engine}")
 
 
-def get_sources_from_files(
+def get_sources_from_items(
     request,
-    files,
+    items,
     queries,
     embedding_function,
     k,
@@ -455,75 +457,90 @@ def get_sources_from_files(
     hybrid_bm25_weight,
     hybrid_search,
     full_context=False,
+    user: Optional[UserModel] = None,
 ):
     log.debug(
-        f"files: {files} {queries} {embedding_function} {reranking_function} {full_context}"
+        f"items: {items} {queries} {embedding_function} {reranking_function} {full_context}"
     )
 
     extracted_collections = []
     query_results = []
 
-    for file in files:
+    for item in items:
         query_result = None
-        if file.get("docs"):
-            # BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL
-            query_result = {
-                "documents": [[doc.get("content") for doc in file.get("docs")]],
-                "metadatas": [[doc.get("metadata") for doc in file.get("docs")]],
-            }
-        elif file.get("type") == "text":
+        if item.get("type") == "text":
             # Text File
+            # Used during temporary chat file uploads
             query_result = {
-                "documents": [[file.get("content")]],
-                "metadatas": [[{"file_id": file.get("id"), "name": file.get("name")}]],
+                "documents": [[item.get("content")]],
+                "metadatas": [[{"file_id": item.get("id"), "name": item.get("name")}]],
             }
-        elif file.get("type") == "note":
+
+        elif item.get("type") == "note":
             # Note Attached
-            note = Notes.get_note_by_id(file.get("id"))
+            note = Notes.get_note_by_id(item.get("id"))
 
+            if user.role == "admin" or has_access(user.id, "read", note.access_control):
+                # User has access to the note
+                query_result = {
+                    "documents": [[note.data.get("content", {}).get("md", "")]],
+                    "metadatas": [[{"file_id": note.id, "name": note.title}]],
+                }
+
+        elif item.get("docs"):
+            # BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL
             query_result = {
-                "documents": [[note.data.get("content", {}).get("md", "")]],
-                "metadatas": [[{"file_id": note.id, "name": note.title}]],
+                "documents": [[doc.get("content") for doc in item.get("docs")]],
+                "metadatas": [[doc.get("metadata") for doc in item.get("docs")]],
             }
-        elif file.get("context") == "full":
-            if file.get("type") == "file":
+
+        elif item.get("context") == "full":
+            if item.get("type") == "file":
                 # Manual Full Mode Toggle
+                # Used from chat file modal, we can assume that the file content will be available from item.get("file").get("data", {}).get("content")
                 query_result = {
-                    "documents": [[file.get("file").get("data", {}).get("content")]],
+                    "documents": [[item.get("file").get("data", {}).get("content")]],
                     "metadatas": [
-                        [{"file_id": file.get("id"), "name": file.get("name")}]
+                        [{"file_id": item.get("id"), "name": item.get("name")}]
                     ],
                 }
-            elif file.get("type") == "collection":
+            elif item.get("type") == "collection":
                 # Manual Full Mode Toggle for Collection
-                file_ids = file.get("data", {}).get("file_ids", [])
+                knowledge_base = Knowledges.get_knowledge_by_id(item.get("id"))
 
-                documents = []
-                metadatas = []
-                for file_id in file_ids:
-                    file_object = Files.get_file_by_id(file_id)
+                if knowledge_base and (
+                    user.role == "admin"
+                    or has_access(user.id, "read", knowledge_base.access_control)
+                ):
 
-                    if file_object:
-                        documents.append(file_object.data.get("content", ""))
-                        metadatas.append(
-                            {
-                                "file_id": file_id,
-                                "name": file_object.filename,
-                                "source": file_object.filename,
-                            }
-                        )
+                    file_ids = knowledge_base.data.get("file_ids", [])
 
-                query_result = {
-                    "documents": [documents],
-                    "metadatas": [metadatas],
-                }
+                    documents = []
+                    metadatas = []
+                    for file_id in file_ids:
+                        file_object = Files.get_file_by_id(file_id)
+
+                        if file_object:
+                            documents.append(file_object.data.get("content", ""))
+                            metadatas.append(
+                                {
+                                    "file_id": file_id,
+                                    "name": file_object.filename,
+                                    "source": file_object.filename,
+                                }
+                            )
+
+                    query_result = {
+                        "documents": [documents],
+                        "metadatas": [metadatas],
+                    }
         elif (
-            file.get("type") != "web_search"
+            item.get("type") != "web_search"
             and request.app.state.config.BYPASS_EMBEDDING_AND_RETRIEVAL
         ):
             # BYPASS_EMBEDDING_AND_RETRIEVAL
-            if file.get("type") == "collection":
-                file_ids = file.get("data", {}).get("file_ids", [])
+            if item.get("type") == "collection":
+                file_ids = item.get("data", {}).get("file_ids", [])
 
                 documents = []
                 metadatas = []
@@ -545,46 +562,46 @@ def get_sources_from_files(
                     "metadatas": [metadatas],
                 }
 
-            elif file.get("id"):
-                file_object = Files.get_file_by_id(file.get("id"))
+            elif item.get("id"):
+                file_object = Files.get_file_by_id(item.get("id"))
                 if file_object:
                     query_result = {
                         "documents": [[file_object.data.get("content", "")]],
                         "metadatas": [
                             [
                                 {
-                                    "file_id": file.get("id"),
+                                    "file_id": item.get("id"),
                                     "name": file_object.filename,
                                     "source": file_object.filename,
                                 }
                             ]
                         ],
                     }
-            elif file.get("file").get("data"):
+            elif item.get("file").get("data"):
                 query_result = {
-                    "documents": [[file.get("file").get("data", {}).get("content")]],
+                    "documents": [[item.get("file").get("data", {}).get("content")]],
                     "metadatas": [
-                        [file.get("file").get("data", {}).get("metadata", {})]
+                        [item.get("file").get("data", {}).get("metadata", {})]
                     ],
                 }
         else:
             collection_names = []
-            if file.get("type") == "collection":
-                if file.get("legacy"):
-                    collection_names = file.get("collection_names", [])
+            if item.get("type") == "collection":
+                if item.get("legacy"):
+                    collection_names = item.get("collection_names", [])
                 else:
-                    collection_names.append(file["id"])
-            elif file.get("collection_name"):
-                collection_names.append(file["collection_name"])
-            elif file.get("id"):
-                if file.get("legacy"):
-                    collection_names.append(f"{file['id']}")
+                    collection_names.append(item["id"])
+            elif item.get("collection_name"):
+                collection_names.append(item["collection_name"])
+            elif item.get("id"):
+                if item.get("legacy"):
+                    collection_names.append(f"{item['id']}")
                 else:
-                    collection_names.append(f"file-{file['id']}")
+                    collection_names.append(f"file-{item['id']}")
 
             collection_names = set(collection_names).difference(extracted_collections)
             if not collection_names:
-                log.debug(f"skipping {file} as it has already been extracted")
+                log.debug(f"skipping {item} as it has already been extracted")
                 continue
 
             if full_context:
@@ -596,14 +613,14 @@ def get_sources_from_files(
             else:
                 try:
                     query_result = None
-                    if file.get("type") == "text":
+                    if item.get("type") == "text":
                         # Not sure when this is used, but it seems to be a fallback
                         query_result = {
                             "documents": [
-                                [file.get("file").get("data", {}).get("content")]
+                                [item.get("file").get("data", {}).get("content")]
                             ],
                             "metadatas": [
-                                [file.get("file").get("data", {}).get("meta", {})]
+                                [item.get("file").get("data", {}).get("meta", {})]
                             ],
                         }
                     else:
@@ -638,10 +655,10 @@ def get_sources_from_files(
             extracted_collections.extend(collection_names)
 
         if query_result:
-            if "data" in file:
-                del file["data"]
+            if "data" in item:
+                del item["data"]
 
-            query_results.append({**query_result, "file": file})
+            query_results.append({**query_result, "file": item})
 
     sources = []
     for query_result in query_results:

+ 8 - 5
backend/open_webui/utils/middleware.py

@@ -56,7 +56,7 @@ from open_webui.models.users import UserModel
 from open_webui.models.functions import Functions
 from open_webui.models.models import Models
 
-from open_webui.retrieval.utils import get_sources_from_files
+from open_webui.retrieval.utils import get_sources_from_items
 
 
 from open_webui.utils.chat import generate_chat_completion
@@ -638,14 +638,14 @@ async def chat_completion_files_handler(
             queries = [get_last_user_message(body["messages"])]
 
         try:
-            # Offload get_sources_from_files to a separate thread
+            # Offload get_sources_from_items to a separate thread
             loop = asyncio.get_running_loop()
             with ThreadPoolExecutor() as executor:
                 sources = await loop.run_in_executor(
                     executor,
-                    lambda: get_sources_from_files(
+                    lambda: get_sources_from_items(
                         request=request,
-                        files=files,
+                        items=files,
                         queries=queries,
                         embedding_function=lambda query, prefix: request.app.state.EMBEDDING_FUNCTION(
                             query, prefix=prefix, user=user
@@ -657,6 +657,7 @@ async def chat_completion_files_handler(
                         hybrid_bm25_weight=request.app.state.config.HYBRID_BM25_WEIGHT,
                         hybrid_search=request.app.state.config.ENABLE_RAG_HYBRID_SEARCH,
                         full_context=request.app.state.config.RAG_FULL_CONTEXT,
+                        user=user,
                     ),
                 )
         except Exception as e:
@@ -2152,7 +2153,9 @@ async def process_chat_response(
                         if isinstance(tool_result, dict) or isinstance(
                             tool_result, list
                         ):
-                            tool_result = json.dumps(tool_result, indent=2, ensure_ascii=False)
+                            tool_result = json.dumps(
+                                tool_result, indent=2, ensure_ascii=False
+                            )
 
                         results.append(
                             {