瀏覽代碼

refac: retain metadata for collection

Timothy J. Baek 7 月之前
父節點
當前提交
1f9b5b6456

+ 32 - 17
backend/open_webui/apps/retrieval/main.py

@@ -733,15 +733,10 @@ def process_file(
         file = Files.get_file_by_id(form_data.file_id)
 
         collection_name = form_data.collection_name
+
         if collection_name is None:
             collection_name = f"file-{file.id}"
 
-        loader = Loader(
-            engine=app.state.config.CONTENT_EXTRACTION_ENGINE,
-            TIKA_SERVER_URL=app.state.config.TIKA_SERVER_URL,
-            PDF_EXTRACT_IMAGES=app.state.config.PDF_EXTRACT_IMAGES,
-        )
-
         if form_data.content:
             docs = [
                 Document(
@@ -755,21 +750,41 @@ def process_file(
             ]
 
             text_content = form_data.content
-        elif file.data.get("content", None):
-            docs = [
-                Document(
-                    page_content=file.data.get("content", ""),
-                    metadata={
-                        "name": file.meta.get("name", file.filename),
-                        "created_by": file.user_id,
-                        **file.meta,
-                    },
-                )
-            ]
+        elif form_data.collection_name:
+            result = VECTOR_DB_CLIENT.query(
+                collection_name=f"file-{file.id}", filter={"file_id": file.id}
+            )
+
+            if result:
+                docs = [
+                    Document(
+                        page_content=result.documents[0][idx],
+                        metadata=result.metadatas[0][idx],
+                    )
+                    for idx, id in enumerate(result.ids[0])
+                ]
+            else:
+                docs = [
+                    Document(
+                        page_content=file.data.get("content", ""),
+                        metadata={
+                            "name": file.meta.get("name", file.filename),
+                            "created_by": file.user_id,
+                            **file.meta,
+                        },
+                    )
+                ]
+
             text_content = file.data.get("content", "")
         else:
             file_path = file.meta.get("path", None)
             if file_path:
+                loader = Loader(
+                    engine=app.state.config.CONTENT_EXTRACTION_ENGINE,
+                    TIKA_SERVER_URL=app.state.config.TIKA_SERVER_URL,
+                    PDF_EXTRACT_IMAGES=app.state.config.PDF_EXTRACT_IMAGES,
+                )
+
                 docs = loader.load(
                     file.filename, file.meta.get("content_type"), file_path
                 )

+ 1 - 4
backend/open_webui/apps/retrieval/vector/dbs/chroma.py

@@ -70,10 +70,9 @@ class ChromaClient:
             return None
 
     def query(
-        self, collection_name: str, filter: dict, limit: int = 2
+        self, collection_name: str, filter: dict, limit: Optional[int] = None
     ) -> Optional[GetResult]:
         # Query the items from the collection based on the filter.
-
         try:
             collection = self.client.get_collection(name=collection_name)
             if collection:
@@ -82,8 +81,6 @@ class ChromaClient:
                     limit=limit,
                 )
 
-                print(result)
-
                 return GetResult(
                     **{
                         "ids": [result["ids"]],

+ 40 - 10
backend/open_webui/apps/retrieval/vector/dbs/milvus.py

@@ -135,10 +135,8 @@ class MilvusClient:
 
         return self._result_to_search_result(result)
 
-    def query(
-        self, collection_name: str, filter: dict, limit: int = 1
-    ) -> Optional[GetResult]:
-        # Query the items from the collection based on the filter.
+    def query(self, collection_name: str, filter: dict, limit: Optional[int] = None):
+        # Construct the filter string for querying
         filter_string = " && ".join(
             [
                 f"JSON_CONTAINS(metadata[{key}], '{[value] if isinstance(value, str) else value}')"
@@ -146,13 +144,45 @@ class MilvusClient:
             ]
         )
 
-        result = self.client.query(
-            collection_name=f"{self.collection_prefix}_{collection_name}",
-            filter=filter_string,
-            limit=limit,
-        )
+        max_limit = 16383  # The maximum number of records per request
+        all_results = []
 
-        return self._result_to_get_result([result])
+        if limit is None:
+            limit = float("inf")  # Use infinity as a placeholder for no limit
+
+        # Initialize offset and remaining to handle pagination
+        offset = 0
+        remaining = limit
+
+        # Loop until there are no more items to fetch or the desired limit is reached
+        while remaining > 0:
+            current_fetch = min(
+                max_limit, remaining
+            )  # Determine how many items to fetch in this iteration
+
+            results = self.client.query(
+                collection_name=f"{self.collection_prefix}_{collection_name}",
+                filter=filter_string,
+                output_fields=["*"],
+                limit=current_fetch,
+                offset=offset,
+            )
+
+            if not results:
+                break
+
+            all_results.extend(results)
+            results_count = len(results)
+            remaining -= (
+                results_count  # Decrease remaining by the number of items fetched
+            )
+            offset += results_count
+
+            # Break the loop if the results returned are less than the requested fetch count
+            if results_count < current_fetch:
+                break
+
+        return self._result_to_get_result(all_results)
 
     def get(self, collection_name: str) -> Optional[GetResult]:
         # Get all the items in the collection.