Ver Fonte

Merge pull request #17837 from Classic298/milvus-multitenancy

feat: Impelement Milvus multitenancy // breaking: set milvus multitenancy as standard option (just like Qdrant already is)
Tim Jaeryang Baek há 1 semana atrás
pai
commit
2d94b8e905

+ 5 - 2
backend/open_webui/config.py

@@ -2005,11 +2005,9 @@ if VECTOR_DB == "chroma":
 # this uses the model defined in the Dockerfile ENV variable. If you dont use docker or docker based deployments such as k8s, the default embedding model will be used (sentence-transformers/all-MiniLM-L6-v2)
 
 # Milvus
-
 MILVUS_URI = os.environ.get("MILVUS_URI", f"{DATA_DIR}/vector_db/milvus.db")
 MILVUS_DB = os.environ.get("MILVUS_DB", "default")
 MILVUS_TOKEN = os.environ.get("MILVUS_TOKEN", None)
-
 MILVUS_INDEX_TYPE = os.environ.get("MILVUS_INDEX_TYPE", "HNSW")
 MILVUS_METRIC_TYPE = os.environ.get("MILVUS_METRIC_TYPE", "COSINE")
 MILVUS_HNSW_M = int(os.environ.get("MILVUS_HNSW_M", "16"))
@@ -2019,6 +2017,11 @@ MILVUS_DISKANN_MAX_DEGREE = int(os.environ.get("MILVUS_DISKANN_MAX_DEGREE", "56"
 MILVUS_DISKANN_SEARCH_LIST_SIZE = int(
     os.environ.get("MILVUS_DISKANN_SEARCH_LIST_SIZE", "100")
 )
+ENABLE_MILVUS_MULTITENANCY_MODE = (
+    os.environ.get("ENABLE_MILVUS_MULTITENANCY_MODE", "true").lower() == "true"
+)
+# Hyphens not allowed, need to use underscores in collection names
+MILVUS_COLLECTION_PREFIX = os.environ.get("MILVUS_COLLECTION_PREFIX", "open_webui")
 
 # Qdrant
 QDRANT_URI = os.environ.get("QDRANT_URI", None)

+ 288 - 0
backend/open_webui/retrieval/vector/dbs/milvus_multitenancy.py

@@ -0,0 +1,288 @@
+import logging
+from typing import Optional, Tuple, List, Dict, Any
+
+from open_webui.config import (
+    MILVUS_URI,
+    MILVUS_TOKEN,
+    MILVUS_DB,
+    MILVUS_COLLECTION_PREFIX,
+    MILVUS_INDEX_TYPE,
+    MILVUS_METRIC_TYPE,
+    MILVUS_HNSW_M,
+    MILVUS_HNSW_EFCONSTRUCTION,
+    MILVUS_IVF_FLAT_NLIST,
+)
+from open_webui.env import SRC_LOG_LEVELS
+from open_webui.retrieval.vector.main import (
+    GetResult,
+    SearchResult,
+    VectorDBBase,
+    VectorItem,
+)
+from pymilvus import (
+    connections,
+    utility,
+    Collection,
+    CollectionSchema,
+    FieldSchema,
+    DataType,
+)
+
+log = logging.getLogger(__name__)
+log.setLevel(SRC_LOG_LEVELS["RAG"])
+
+RESOURCE_ID_FIELD = "resource_id"
+
+
+class MilvusClient(VectorDBBase):
+    def __init__(self):
+        # Milvus collection names can only contain numbers, letters, and underscores.
+        self.collection_prefix = MILVUS_COLLECTION_PREFIX.replace("-", "_")
+        connections.connect(
+            alias="default",
+            uri=MILVUS_URI,
+            token=MILVUS_TOKEN,
+            db_name=MILVUS_DB,
+        )
+
+        # Main collection types for multi-tenancy
+        self.MEMORY_COLLECTION = f"{self.collection_prefix}_memories"
+        self.KNOWLEDGE_COLLECTION = f"{self.collection_prefix}_knowledge"
+        self.FILE_COLLECTION = f"{self.collection_prefix}_files"
+        self.WEB_SEARCH_COLLECTION = f"{self.collection_prefix}_web_search"
+        self.HASH_BASED_COLLECTION = f"{self.collection_prefix}_hash_based"
+        self.shared_collections = [
+            self.MEMORY_COLLECTION,
+            self.KNOWLEDGE_COLLECTION,
+            self.FILE_COLLECTION,
+            self.WEB_SEARCH_COLLECTION,
+            self.HASH_BASED_COLLECTION,
+        ]
+
+    def _get_collection_and_resource_id(self, collection_name: str) -> Tuple[str, str]:
+        """
+        Maps the traditional collection name to multi-tenant collection and resource ID.
+        
+        WARNING: This mapping relies on current Open WebUI naming conventions for 
+        collection names. If Open WebUI changes how it generates collection names
+        (e.g., "user-memory-" prefix, "file-" prefix, web search patterns, or hash 
+        formats), this mapping will break and route data to incorrect collections.
+        POTENTIALLY CAUSING HUGE DATA CORRUPTION, DATA CONSISTENCY ISSUES AND INCORRECT
+        DATA MAPPING INSIDE THE DATABASE.
+        """
+        resource_id = collection_name
+
+        if collection_name.startswith("user-memory-"):
+            return self.MEMORY_COLLECTION, resource_id
+        elif collection_name.startswith("file-"):
+            return self.FILE_COLLECTION, resource_id
+        elif collection_name.startswith("web-search-"):
+            return self.WEB_SEARCH_COLLECTION, resource_id
+        elif len(collection_name) == 63 and all(
+            c in "0123456789abcdef" for c in collection_name
+        ):
+            return self.HASH_BASED_COLLECTION, resource_id
+        else:
+            return self.KNOWLEDGE_COLLECTION, resource_id
+
+    def _create_shared_collection(self, mt_collection_name: str, dimension: int):
+        fields = [
+            FieldSchema(
+                name="id",
+                dtype=DataType.VARCHAR,
+                is_primary=True,
+                auto_id=False,
+                max_length=36,
+            ),
+            FieldSchema(
+                name="vector", dtype=DataType.FLOAT_VECTOR, dim=dimension
+            ),
+            FieldSchema(name="text", dtype=DataType.VARCHAR, max_length=65535),
+            FieldSchema(name="metadata", dtype=DataType.JSON),
+            FieldSchema(
+                name=RESOURCE_ID_FIELD, dtype=DataType.VARCHAR, max_length=255
+            ),
+        ]
+        schema = CollectionSchema(fields, "Shared collection for multi-tenancy")
+        collection = Collection(mt_collection_name, schema)
+
+        index_params = {
+            "metric_type": MILVUS_METRIC_TYPE,
+            "index_type": MILVUS_INDEX_TYPE,
+            "params": {},
+        }
+        if MILVUS_INDEX_TYPE == "HNSW":
+            index_params["params"] = {
+                "M": MILVUS_HNSW_M,
+                "efConstruction": MILVUS_HNSW_EFCONSTRUCTION,
+            }
+        elif MILVUS_INDEX_TYPE == "IVF_FLAT":
+            index_params["params"] = {"nlist": MILVUS_IVF_FLAT_NLIST}
+
+        collection.create_index("vector", index_params)
+        collection.create_index(RESOURCE_ID_FIELD)
+        log.info(f"Created shared collection: {mt_collection_name}")
+        return collection
+
+    def _ensure_collection(self, mt_collection_name: str, dimension: int):
+        if not utility.has_collection(mt_collection_name):
+            self._create_shared_collection(mt_collection_name, dimension)
+
+    def has_collection(self, collection_name: str) -> bool:
+        mt_collection, resource_id = self._get_collection_and_resource_id(
+            collection_name
+        )
+        if not utility.has_collection(mt_collection):
+            return False
+
+        collection = Collection(mt_collection)
+        collection.load()
+        res = collection.query(
+            expr=f"{RESOURCE_ID_FIELD} == '{resource_id}'", limit=1
+        )
+        return len(res) > 0
+
+    def upsert(self, collection_name: str, items: List[VectorItem]):
+        if not items:
+            return
+        mt_collection, resource_id = self._get_collection_and_resource_id(
+            collection_name
+        )
+        dimension = len(items[0]["vector"])
+        self._ensure_collection(mt_collection, dimension)
+        collection = Collection(mt_collection)
+
+        entities = [
+            {
+                "id": item["id"],
+                "vector": item["vector"],
+                "text": item["text"],
+                "metadata": item["metadata"],
+                RESOURCE_ID_FIELD: resource_id,
+            }
+            for item in items
+        ]
+        collection.insert(entities)
+        collection.flush()
+
+    def search(
+        self, collection_name: str, vectors: List[List[float]], limit: int
+    ) -> Optional[SearchResult]:
+        if not vectors:
+            return None
+
+        mt_collection, resource_id = self._get_collection_and_resource_id(
+            collection_name
+        )
+        if not utility.has_collection(mt_collection):
+            return None
+
+        collection = Collection(mt_collection)
+        collection.load()
+
+        search_params = {"metric_type": MILVUS_METRIC_TYPE, "params": {}}
+        results = collection.search(
+            data=vectors,
+            anns_field="vector",
+            param=search_params,
+            limit=limit,
+            expr=f"{RESOURCE_ID_FIELD} == '{resource_id}'",
+            output_fields=["id", "text", "metadata"],
+        )
+
+        ids, documents, metadatas, distances = [], [], [], []
+        for hits in results:
+            batch_ids, batch_docs, batch_metadatas, batch_dists = [], [], [], []
+            for hit in hits:
+                batch_ids.append(hit.entity.get("id"))
+                batch_docs.append(hit.entity.get("text"))
+                batch_metadatas.append(hit.entity.get("metadata"))
+                batch_dists.append(hit.distance)
+            ids.append(batch_ids)
+            documents.append(batch_docs)
+            metadatas.append(batch_metadatas)
+            distances.append(batch_dists)
+
+        return SearchResult(
+            ids=ids, documents=documents, metadatas=metadatas, distances=distances
+        )
+
+    def delete(
+        self,
+        collection_name: str,
+        ids: Optional[List[str]] = None,
+        filter: Optional[Dict[str, Any]] = None,
+    ):
+        mt_collection, resource_id = self._get_collection_and_resource_id(
+            collection_name
+        )
+        if not utility.has_collection(mt_collection):
+            return
+
+        collection = Collection(mt_collection)
+        
+        # Build expression
+        expr = [f"{RESOURCE_ID_FIELD} == '{resource_id}'"]
+        if ids:
+            # Milvus expects a string list for 'in' operator
+            id_list_str = ", ".join([f"'{id_val}'" for id_val in ids])
+            expr.append(f"id in [{id_list_str}]")
+        
+        if filter:
+            for key, value in filter.items():
+                 expr.append(f"metadata['{key}'] == '{value}'")
+        
+        collection.delete(" and ".join(expr))
+
+    def reset(self):
+        for collection_name in self.shared_collections:
+            if utility.has_collection(collection_name):
+                utility.drop_collection(collection_name)
+
+    def delete_collection(self, collection_name: str):
+        mt_collection, resource_id = self._get_collection_and_resource_id(
+            collection_name
+        )
+        if not utility.has_collection(mt_collection):
+            return
+        
+        collection = Collection(mt_collection)
+        collection.delete(f"{RESOURCE_ID_FIELD} == '{resource_id}'")
+
+    def query(
+        self, collection_name: str, filter: Dict[str, Any], limit: Optional[int] = None
+    ) -> Optional[GetResult]:
+        mt_collection, resource_id = self._get_collection_and_resource_id(
+            collection_name
+        )
+        if not utility.has_collection(mt_collection):
+            return None
+
+        collection = Collection(mt_collection)
+        collection.load()
+
+        expr = [f"{RESOURCE_ID_FIELD} == '{resource_id}'"]
+        if filter:
+            for key, value in filter.items():
+                if isinstance(value, str):
+                    expr.append(f"metadata['{key}'] == '{value}'")
+                else:
+                    expr.append(f"metadata['{key}'] == {value}")
+
+        results = collection.query(
+            expr=" and ".join(expr),
+            output_fields=["id", "text", "metadata"],
+            limit=limit,
+        )
+
+        ids = [res["id"] for res in results]
+        documents = [res["text"] for res in results]
+        metadatas = [res["metadata"] for res in results]
+
+        return GetResult(ids=[ids], documents=[documents], metadatas=[metadatas])
+
+    def get(self, collection_name: str) -> Optional[GetResult]:
+        return self.query(collection_name, filter={}, limit=None)
+
+    def insert(self, collection_name: str, items: List[VectorItem]):
+        return self.upsert(collection_name, items)

+ 7 - 0
backend/open_webui/retrieval/vector/dbs/qdrant_multitenancy.py

@@ -105,6 +105,13 @@ class QdrantClient(VectorDBBase):
 
         Returns:
             tuple: (collection_name, tenant_id)
+
+        WARNING: This mapping relies on current Open WebUI naming conventions for 
+        collection names. If Open WebUI changes how it generates collection names
+        (e.g., "user-memory-" prefix, "file-" prefix, web search patterns, or hash 
+        formats), this mapping will break and route data to incorrect collections.
+        POTENTIALLY CAUSING HUGE DATA CORRUPTION, DATA CONSISTENCY ISSUES AND INCORRECT
+        DATA MAPPING INSIDE THE DATABASE.
         """
         # Check for user memory collections
         tenant_id = collection_name

+ 14 - 3
backend/open_webui/retrieval/vector/factory.py

@@ -1,6 +1,10 @@
 from open_webui.retrieval.vector.main import VectorDBBase
 from open_webui.retrieval.vector.type import VectorType
-from open_webui.config import VECTOR_DB, ENABLE_QDRANT_MULTITENANCY_MODE
+from open_webui.config import (
+    VECTOR_DB,
+    ENABLE_QDRANT_MULTITENANCY_MODE,
+    ENABLE_MILVUS_MULTITENANCY_MODE,
+)
 
 
 class Vector:
@@ -12,9 +16,16 @@ class Vector:
         """
         match vector_type:
             case VectorType.MILVUS:
-                from open_webui.retrieval.vector.dbs.milvus import MilvusClient
+                if ENABLE_MILVUS_MULTITENANCY_MODE:
+                    from open_webui.retrieval.vector.dbs.milvus_multitenancy import (
+                        MilvusClient,
+                    )
 
-                return MilvusClient()
+                    return MilvusClient()
+                else:
+                    from open_webui.retrieval.vector.dbs.milvus import MilvusClient
+    
+                    return MilvusClient()
             case VectorType.QDRANT:
                 if ENABLE_QDRANT_MULTITENANCY_MODE:
                     from open_webui.retrieval.vector.dbs.qdrant_multitenancy import (