浏览代码

Merge pull request #13670 from HarrisonConsulting/fix/milvus-standalone-index

fix: enhance MilvusClient with dynamic index type and improved logging
Tim Jaeryang Baek 5 月之前
父节点
当前提交
1fea4f794f
共有 2 个文件被更改,包括 115 次插入49 次删除
  1. 6 0
      backend/open_webui/config.py
  2. 109 49
      backend/open_webui/retrieval/vector/dbs/milvus.py

+ 6 - 0
backend/open_webui/config.py

@@ -1765,6 +1765,12 @@ MILVUS_URI = os.environ.get("MILVUS_URI", f"{DATA_DIR}/vector_db/milvus.db")
 MILVUS_DB = os.environ.get("MILVUS_DB", "default")
 MILVUS_TOKEN = os.environ.get("MILVUS_TOKEN", None)
 
+MILVUS_INDEX_TYPE = os.environ.get("MILVUS_INDEX_TYPE", "HNSW")
+MILVUS_METRIC_TYPE = os.environ.get("MILVUS_METRIC_TYPE", "COSINE")
+MILVUS_HNSW_M = int(os.environ.get("MILVUS_HNSW_M", "16"))
+MILVUS_HNSW_EFCONSTRUCTION = int(os.environ.get("MILVUS_HNSW_EFCONSTRUCTION", "100"))
+MILVUS_IVF_FLAT_NLIST = int(os.environ.get("MILVUS_IVF_FLAT_NLIST", "128"))
+
 # Qdrant
 QDRANT_URI = os.environ.get("QDRANT_URI", None)
 QDRANT_API_KEY = os.environ.get("QDRANT_API_KEY", None)

+ 109 - 49
backend/open_webui/retrieval/vector/dbs/milvus.py

@@ -3,7 +3,6 @@ from pymilvus import FieldSchema, DataType
 import json
 import logging
 from typing import Optional
-
 from open_webui.retrieval.vector.main import (
     VectorDBBase,
     VectorItem,
@@ -14,13 +13,17 @@ from open_webui.config import (
     MILVUS_URI,
     MILVUS_DB,
     MILVUS_TOKEN,
+    MILVUS_INDEX_TYPE,
+    MILVUS_METRIC_TYPE,
+    MILVUS_HNSW_M,
+    MILVUS_HNSW_EFCONSTRUCTION,
+    MILVUS_IVF_FLAT_NLIST,
 )
 from open_webui.env import SRC_LOG_LEVELS
 
 log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["RAG"])
 
-
 class MilvusClient(VectorDBBase):
     def __init__(self):
         self.collection_prefix = "open_webui"
@@ -33,7 +36,6 @@ class MilvusClient(VectorDBBase):
         ids = []
         documents = []
         metadatas = []
-
         for match in result:
             _ids = []
             _documents = []
@@ -42,11 +44,9 @@ class MilvusClient(VectorDBBase):
                 _ids.append(item.get("id"))
                 _documents.append(item.get("data", {}).get("text"))
                 _metadatas.append(item.get("metadata"))
-
             ids.append(_ids)
             documents.append(_documents)
             metadatas.append(_metadatas)
-
         return GetResult(
             **{
                 "ids": ids,
@@ -60,13 +60,11 @@ class MilvusClient(VectorDBBase):
         distances = []
         documents = []
         metadatas = []
-
         for match in result:
             _ids = []
             _distances = []
             _documents = []
             _metadatas = []
-
             for item in match:
                 _ids.append(item.get("id"))
                 # normalize milvus score from [-1, 1] to [0, 1] range
@@ -75,12 +73,10 @@ class MilvusClient(VectorDBBase):
                 _distances.append(_dist)
                 _documents.append(item.get("entity", {}).get("data", {}).get("text"))
                 _metadatas.append(item.get("entity", {}).get("metadata"))
-
             ids.append(_ids)
             distances.append(_distances)
             documents.append(_documents)
             metadatas.append(_metadatas)
-
         return SearchResult(
             **{
                 "ids": ids,
@@ -113,11 +109,36 @@ class MilvusClient(VectorDBBase):
         )
 
         index_params = self.client.prepare_index_params()
+
+        # Use configurations from config.py
+        index_type = MILVUS_INDEX_TYPE.upper()
+        metric_type = MILVUS_METRIC_TYPE.upper()
+        
+        log.info(f"Using Milvus index type: {index_type}, metric type: {metric_type}")
+
+        index_creation_params = {}
+        if index_type == "HNSW":
+            index_creation_params = {"M": MILVUS_HNSW_M, "efConstruction": MILVUS_HNSW_EFCONSTRUCTION}
+            log.info(f"HNSW params: {index_creation_params}")
+        elif index_type == "IVF_FLAT":
+            index_creation_params = {"nlist": MILVUS_IVF_FLAT_NLIST}
+            log.info(f"IVF_FLAT params: {index_creation_params}")
+        elif index_type in ["FLAT", "AUTOINDEX"]:
+            log.info(f"Using {index_type} index with no specific build-time params.")
+        else:
+            log.warning(
+                f"Unsupported MILVUS_INDEX_TYPE: '{index_type}'. "
+                f"Supported types: HNSW, IVF_FLAT, FLAT, AUTOINDEX. "
+                f"Milvus will use its default for the collection if this type is not directly supported for index creation."
+            )
+            # For unsupported types, pass the type directly to Milvus; it might handle it or use a default.
+            # If Milvus errors out, the user needs to correct the MILVUS_INDEX_TYPE env var.
+
         index_params.add_index(
             field_name="vector",
-            index_type="HNSW",
-            metric_type="COSINE",
-            params={"M": 16, "efConstruction": 100},
+            index_type=index_type,
+            metric_type=metric_type,
+            params=index_creation_params,
         )
 
         self.client.create_collection(
@@ -125,6 +146,8 @@ class MilvusClient(VectorDBBase):
             schema=schema,
             index_params=index_params,
         )
+        log.info(f"Successfully created collection '{self.collection_prefix}_{collection_name}' with index type '{index_type}' and metric '{metric_type}'.")
+
 
     def has_collection(self, collection_name: str) -> bool:
         # Check if the collection exists based on the collection name.
@@ -145,84 +168,95 @@ class MilvusClient(VectorDBBase):
     ) -> Optional[SearchResult]:
         # Search for the nearest neighbor items based on the vectors and return 'limit' number of results.
         collection_name = collection_name.replace("-", "_")
+        # For some index types like IVF_FLAT, search params like nprobe can be set.
+        # Example: search_params = {"nprobe": 10} if using IVF_FLAT
+        # For simplicity, not adding configurable search_params here, but could be extended.
         result = self.client.search(
             collection_name=f"{self.collection_prefix}_{collection_name}",
             data=vectors,
             limit=limit,
             output_fields=["data", "metadata"],
+            # search_params=search_params # Potentially add later if needed
         )
-
         return self._result_to_search_result(result)
 
     def query(self, collection_name: str, filter: dict, limit: Optional[int] = None):
         # Construct the filter string for querying
         collection_name = collection_name.replace("-", "_")
         if not self.has_collection(collection_name):
+            log.warning(f"Query attempted on non-existent collection: {self.collection_prefix}_{collection_name}")
             return None
-
         filter_string = " && ".join(
             [
                 f'metadata["{key}"] == {json.dumps(value)}'
                 for key, value in filter.items()
             ]
         )
-
         max_limit = 16383  # The maximum number of records per request
         all_results = []
-
         if limit is None:
-            limit = float("inf")  # Use infinity as a placeholder for no limit
+            # Milvus default limit for query if not specified is 16384, but docs mention iteration.
+            # Let's set a practical high number if "all" is intended, or handle true pagination.
+            # For now, if limit is None, we'll fetch in batches up to a very large number.
+            # This part could be refined based on expected use cases for "get all".
+            # For this function signature, None implies "as many as possible" up to Milvus limits.
+            limit = 16384 * 10 # A large number to signify fetching many, will be capped by actual data or max_limit per call.
+            log.info(f"Limit not specified for query, fetching up to {limit} results in batches.")
+
 
         # Initialize offset and remaining to handle pagination
         offset = 0
         remaining = limit
-
+        
         try:
+            log.info(f"Querying collection {self.collection_prefix}_{collection_name} with filter: '{filter_string}', limit: {limit}")
             # Loop until there are no more items to fetch or the desired limit is reached
             while remaining > 0:
-                log.info(f"remaining: {remaining}")
-                current_fetch = min(
-                    max_limit, remaining
-                )  # Determine how many items to fetch in this iteration
-
+                current_fetch = min(max_limit, remaining if isinstance(remaining, int) else max_limit)
+                log.debug(f"Querying with offset: {offset}, current_fetch: {current_fetch}")
+                
                 results = self.client.query(
                     collection_name=f"{self.collection_prefix}_{collection_name}",
                     filter=filter_string,
-                    output_fields=["*"],
+                    output_fields=["id", "data", "metadata"], # Explicitly list needed fields. Vector not usually needed in query.
                     limit=current_fetch,
                     offset=offset,
                 )
-
+                
                 if not results:
+                    log.debug("No more results from query.")
                     break
-
+                
                 all_results.extend(results)
                 results_count = len(results)
-                remaining -= (
-                    results_count  # Decrease remaining by the number of items fetched
-                )
-                offset += results_count
+                log.debug(f"Fetched {results_count} results in this batch.")
 
-                # Break the loop if the results returned are less than the requested fetch count
+                if isinstance(remaining, int):
+                    remaining -= results_count
+                
+                offset += results_count
+                
+                # Break the loop if the results returned are less than the requested fetch count (means end of data)
                 if results_count < current_fetch:
+                    log.debug("Fetched less than requested, assuming end of results for this query.")
                     break
-
-            log.debug(all_results)
+            
+            log.info(f"Total results from query: {len(all_results)}")
             return self._result_to_get_result([all_results])
         except Exception as e:
             log.exception(
-                f"Error querying collection {collection_name} with limit {limit}: {e}"
+                f"Error querying collection {self.collection_prefix}_{collection_name} with filter '{filter_string}' and limit {limit}: {e}"
             )
             return None
 
     def get(self, collection_name: str) -> Optional[GetResult]:
-        # Get all the items in the collection.
+        # Get all the items in the collection. This can be very resource-intensive for large collections.
         collection_name = collection_name.replace("-", "_")
-        result = self.client.query(
-            collection_name=f"{self.collection_prefix}_{collection_name}",
-            filter='id != ""',
-        )
-        return self._result_to_get_result([result])
+        log.warning(f"Fetching ALL items from collection '{self.collection_prefix}_{collection_name}'. This might be slow for large collections.")
+        # Using query with a trivial filter to get all items.
+        # This will use the paginated query logic.
+        return self.query(collection_name=collection_name, filter={}, limit=None)
+
 
     def insert(self, collection_name: str, items: list[VectorItem]):
         # Insert the items into the collection, if the collection does not exist, it will be created.
@@ -230,10 +264,15 @@ class MilvusClient(VectorDBBase):
         if not self.client.has_collection(
             collection_name=f"{self.collection_prefix}_{collection_name}"
         ):
+            log.info(f"Collection {self.collection_prefix}_{collection_name} does not exist. Creating now.")
+            if not items:
+                log.error(f"Cannot create collection {self.collection_prefix}_{collection_name} without items to determine dimension.")
+                raise ValueError("Cannot create Milvus collection without items to determine vector dimension.")
             self._create_collection(
                 collection_name=collection_name, dimension=len(items[0]["vector"])
             )
-
+        
+        log.info(f"Inserting {len(items)} items into collection {self.collection_prefix}_{collection_name}.")
         return self.client.insert(
             collection_name=f"{self.collection_prefix}_{collection_name}",
             data=[
@@ -253,10 +292,15 @@ class MilvusClient(VectorDBBase):
         if not self.client.has_collection(
             collection_name=f"{self.collection_prefix}_{collection_name}"
         ):
+            log.info(f"Collection {self.collection_prefix}_{collection_name} does not exist for upsert. Creating now.")
+            if not items:
+                log.error(f"Cannot create collection {self.collection_prefix}_{collection_name} for upsert without items to determine dimension.")
+                raise ValueError("Cannot create Milvus collection for upsert without items to determine vector dimension.")
             self._create_collection(
                 collection_name=collection_name, dimension=len(items[0]["vector"])
             )
-
+        
+        log.info(f"Upserting {len(items)} items into collection {self.collection_prefix}_{collection_name}.")
         return self.client.upsert(
             collection_name=f"{self.collection_prefix}_{collection_name}",
             data=[
@@ -276,30 +320,46 @@ class MilvusClient(VectorDBBase):
         ids: Optional[list[str]] = None,
         filter: Optional[dict] = None,
     ):
-        # Delete the items from the collection based on the ids.
+        # Delete the items from the collection based on the ids or filter.
         collection_name = collection_name.replace("-", "_")
+        if not self.has_collection(collection_name):
+            log.warning(f"Delete attempted on non-existent collection: {self.collection_prefix}_{collection_name}")
+            return None
+
         if ids:
+            log.info(f"Deleting items by IDs from {self.collection_prefix}_{collection_name}. IDs: {ids}")
             return self.client.delete(
                 collection_name=f"{self.collection_prefix}_{collection_name}",
                 ids=ids,
             )
         elif filter:
-            # Convert the filter dictionary to a string using JSON_CONTAINS.
             filter_string = " && ".join(
                 [
                     f'metadata["{key}"] == {json.dumps(value)}'
                     for key, value in filter.items()
                 ]
             )
-
+            log.info(f"Deleting items by filter from {self.collection_prefix}_{collection_name}. Filter: {filter_string}")
             return self.client.delete(
                 collection_name=f"{self.collection_prefix}_{collection_name}",
                 filter=filter_string,
             )
+        else:
+            log.warning(f"Delete operation on {self.collection_prefix}_{collection_name} called without IDs or filter. No action taken.")
+            return None
+
 
     def reset(self):
-        # Resets the database. This will delete all collections and item entries.
+        # Resets the database. This will delete all collections and item entries that match the prefix.
+        log.warning(f"Resetting Milvus: Deleting all collections with prefix '{self.collection_prefix}'.")
         collection_names = self.client.list_collections()
-        for collection_name in collection_names:
-            if collection_name.startswith(self.collection_prefix):
-                self.client.drop_collection(collection_name=collection_name)
+        deleted_collections = []
+        for collection_name_full in collection_names:
+            if collection_name_full.startswith(self.collection_prefix):
+                try:
+                    self.client.drop_collection(collection_name=collection_name_full)
+                    deleted_collections.append(collection_name_full)
+                    log.info(f"Deleted collection: {collection_name_full}")
+                except Exception as e:
+                    log.error(f"Error deleting collection {collection_name_full}: {e}")
+        log.info(f"Milvus reset complete. Deleted collections: {deleted_collections}")