浏览代码

refac/fix: milvus query logic

Timothy Jaeryang Baek 1 月之前
父节点
当前提交
ad98d4300b
共有 1 个文件被更改,包括 22 次插入54 次删除
  1. 22 54
      backend/open_webui/retrieval/vector/dbs/milvus.py

+ 22 - 54
backend/open_webui/retrieval/vector/dbs/milvus.py

@@ -1,5 +1,7 @@
 from pymilvus import MilvusClient as Client
 from pymilvus import MilvusClient as Client
 from pymilvus import FieldSchema, DataType
 from pymilvus import FieldSchema, DataType
+from pymilvus import connections, Collection
+
 import json
 import json
 import logging
 import logging
 from typing import Optional
 from typing import Optional
@@ -188,6 +190,8 @@ class MilvusClient(VectorDBBase):
         return self._result_to_search_result(result)
         return self._result_to_search_result(result)
 
 
     def query(self, collection_name: str, filter: dict, limit: Optional[int] = None):
     def query(self, collection_name: str, filter: dict, limit: Optional[int] = None):
+        connections.connect(uri=MILVUS_URI, token=MILVUS_TOKEN, db_name=MILVUS_DB)
+
         # Construct the filter string for querying
         # Construct the filter string for querying
         collection_name = collection_name.replace("-", "_")
         collection_name = collection_name.replace("-", "_")
         if not self.has_collection(collection_name):
         if not self.has_collection(collection_name):
@@ -201,72 +205,36 @@ class MilvusClient(VectorDBBase):
                 for key, value in filter.items()
                 for key, value in filter.items()
             ]
             ]
         )
         )
-        max_limit = 16383  # The maximum number of records per request
-        all_results = []
-        if limit is None:
-            # Milvus default limit for query if not specified is 16384, but docs mention iteration.
-            # Let's set a practical high number if "all" is intended, or handle true pagination.
-            # For now, if limit is None, we'll fetch in batches up to a very large number.
-            # This part could be refined based on expected use cases for "get all".
-            # For this function signature, None implies "as many as possible" up to Milvus limits.
-            limit = (
-                16384 * 10
-            )  # A large number to signify fetching many, will be capped by actual data or max_limit per call.
-            log.info(
-                f"Limit not specified for query, fetching up to {limit} results in batches."
-            )
 
 
-        # Initialize offset and remaining to handle pagination
-        offset = 0
-        remaining = limit
+        collection = Collection(f"{self.collection_prefix}_{collection_name}")
+        collection.load()
+        all_results = []
 
 
         try:
         try:
             log.info(
             log.info(
                 f"Querying collection {self.collection_prefix}_{collection_name} with filter: '{filter_string}', limit: {limit}"
                 f"Querying collection {self.collection_prefix}_{collection_name} with filter: '{filter_string}', limit: {limit}"
             )
             )
-            # Loop until there are no more items to fetch or the desired limit is reached
-            while remaining > 0:
-                current_fetch = min(
-                    max_limit, remaining if isinstance(remaining, int) else max_limit
-                )
-                log.debug(
-                    f"Querying with offset: {offset}, current_fetch: {current_fetch}"
-                )
 
 
-                results = self.client.query(
-                    collection_name=f"{self.collection_prefix}_{collection_name}",
-                    filter=filter_string,
-                    output_fields=[
-                        "id",
-                        "data",
-                        "metadata",
-                    ],  # Explicitly list needed fields. Vector not usually needed in query.
-                    limit=current_fetch,
-                    offset=offset,
-                )
-
-                if not results:
-                    log.debug("No more results from query.")
-                    break
-
-                all_results.extend(results)
-                results_count = len(results)
-                log.debug(f"Fetched {results_count} results in this batch.")
-
-                if isinstance(remaining, int):
-                    remaining -= results_count
-
-                offset += results_count
+            iterator = collection.query_iterator(
+                filter=filter_string,
+                output_fields=[
+                    "id",
+                    "data",
+                    "metadata",
+                ],
+                limit=limit,  # Pass the limit directly; None means no limit.
+            )
 
 
-                # Break the loop if the results returned are less than the requested fetch count (means end of data)
-                if results_count < current_fetch:
-                    log.debug(
-                        "Fetched less than requested, assuming end of results for this query."
-                    )
+            while True:
+                result = iterator.next()
+                if not result:
+                    iterator.close()
                     break
                     break
+                all_results += result
 
 
             log.info(f"Total results from query: {len(all_results)}")
             log.info(f"Total results from query: {len(all_results)}")
             return self._result_to_get_result([all_results])
             return self._result_to_get_result([all_results])
+
         except Exception as e:
         except Exception as e:
             log.exception(
             log.exception(
                 f"Error querying collection {self.collection_prefix}_{collection_name} with filter '{filter_string}' and limit {limit}: {e}"
                 f"Error querying collection {self.collection_prefix}_{collection_name} with filter '{filter_string}' and limit {limit}: {e}"