Selaa lähdekoodia

Create milvus_multitenancy.py

Classic298 1 viikko sitten
vanhempi
commit
8dc43f9e3a
1 muutettua tiedostoa jossa 281 lisäystä ja 0 poistoa
  1. 281 0
      backend/open_webui/retrieval/vector/dbs/milvus_multitenancy.py

+ 281 - 0
backend/open_webui/retrieval/vector/dbs/milvus_multitenancy.py

@@ -0,0 +1,281 @@
+import logging
+from typing import Optional, Tuple, List, Dict, Any
+
+from open_webui.config import (
+    MILVUS_URI,
+    MILVUS_TOKEN,
+    MILVUS_DB,
+    MILVUS_COLLECTION_PREFIX,
+    MILVUS_INDEX_TYPE,
+    MILVUS_METRIC_TYPE,
+    MILVUS_HNSW_M,
+    MILVUS_HNSW_EFCONSTRUCTION,
+    MILVUS_IVF_FLAT_NLIST,
+)
+from open_webui.env import SRC_LOG_LEVELS
+from open_webui.retrieval.vector.main import (
+    GetResult,
+    SearchResult,
+    VectorDBBase,
+    VectorItem,
+)
+from pymilvus import (
+    connections,
+    utility,
+    Collection,
+    CollectionSchema,
+    FieldSchema,
+    DataType,
+)
+
+log = logging.getLogger(__name__)
+log.setLevel(SRC_LOG_LEVELS["RAG"])
+
+RESOURCE_ID_FIELD = "resource_id"
+
+
+class MilvusClient(VectorDBBase):
+    def __init__(self):
+        # Milvus collection names can only contain numbers, letters, and underscores.
+        self.collection_prefix = MILVUS_COLLECTION_PREFIX.replace("-", "_")
+        connections.connect(
+            alias="default",
+            uri=MILVUS_URI,
+            token=MILVUS_TOKEN,
+            db_name=MILVUS_DB,
+        )
+
+        # Main collection types for multi-tenancy
+        self.MEMORY_COLLECTION = f"{self.collection_prefix}_memories"
+        self.KNOWLEDGE_COLLECTION = f"{self.collection_prefix}_knowledge"
+        self.FILE_COLLECTION = f"{self.collection_prefix}_files"
+        self.WEB_SEARCH_COLLECTION = f"{self.collection_prefix}_web_search"
+        self.HASH_BASED_COLLECTION = f"{self.collection_prefix}_hash_based"
+        self.shared_collections = [
+            self.MEMORY_COLLECTION,
+            self.KNOWLEDGE_COLLECTION,
+            self.FILE_COLLECTION,
+            self.WEB_SEARCH_COLLECTION,
+            self.HASH_BASED_COLLECTION,
+        ]
+
+    def _get_collection_and_resource_id(self, collection_name: str) -> Tuple[str, str]:
+        """
+        Maps the traditional collection name to multi-tenant collection and resource ID.
+        """
+        resource_id = collection_name
+
+        if collection_name.startswith("user-memory-"):
+            return self.MEMORY_COLLECTION, resource_id
+        elif collection_name.startswith("file-"):
+            return self.FILE_COLLECTION, resource_id
+        elif collection_name.startswith("web-search-"):
+            return self.WEB_SEARCH_COLLECTION, resource_id
+        elif len(collection_name) == 63 and all(
+            c in "0123456789abcdef" for c in collection_name
+        ):
+            return self.HASH_BASED_COLLECTION, resource_id
+        else:
+            return self.KNOWLEDGE_COLLECTION, resource_id
+
+    def _create_shared_collection(self, mt_collection_name: str, dimension: int):
+        fields = [
+            FieldSchema(
+                name="id",
+                dtype=DataType.VARCHAR,
+                is_primary=True,
+                auto_id=False,
+                max_length=36,
+            ),
+            FieldSchema(
+                name="vector", dtype=DataType.FLOAT_VECTOR, dim=dimension
+            ),
+            FieldSchema(name="text", dtype=DataType.VARCHAR, max_length=65535),
+            FieldSchema(name="metadata", dtype=DataType.JSON),
+            FieldSchema(
+                name=RESOURCE_ID_FIELD, dtype=DataType.VARCHAR, max_length=255
+            ),
+        ]
+        schema = CollectionSchema(fields, "Shared collection for multi-tenancy")
+        collection = Collection(mt_collection_name, schema)
+
+        index_params = {
+            "metric_type": MILVUS_METRIC_TYPE,
+            "index_type": MILVUS_INDEX_TYPE,
+            "params": {},
+        }
+        if MILVUS_INDEX_TYPE == "HNSW":
+            index_params["params"] = {
+                "M": MILVUS_HNSW_M,
+                "efConstruction": MILVUS_HNSW_EFCONSTRUCTION,
+            }
+        elif MILVUS_INDEX_TYPE == "IVF_FLAT":
+            index_params["params"] = {"nlist": MILVUS_IVF_FLAT_NLIST}
+
+        collection.create_index("vector", index_params)
+        collection.create_index(RESOURCE_ID_FIELD)
+        log.info(f"Created shared collection: {mt_collection_name}")
+        return collection
+
+    def _ensure_collection(self, mt_collection_name: str, dimension: int):
+        if not utility.has_collection(mt_collection_name):
+            self._create_shared_collection(mt_collection_name, dimension)
+
+    def has_collection(self, collection_name: str) -> bool:
+        mt_collection, resource_id = self._get_collection_and_resource_id(
+            collection_name
+        )
+        if not utility.has_collection(mt_collection):
+            return False
+
+        collection = Collection(mt_collection)
+        collection.load()
+        res = collection.query(
+            expr=f"{RESOURCE_ID_FIELD} == '{resource_id}'", limit=1
+        )
+        return len(res) > 0
+
+    def upsert(self, collection_name: str, items: List[VectorItem]):
+        if not items:
+            return
+        mt_collection, resource_id = self._get_collection_and_resource_id(
+            collection_name
+        )
+        dimension = len(items[0]["vector"])
+        self._ensure_collection(mt_collection, dimension)
+        collection = Collection(mt_collection)
+
+        entities = [
+            {
+                "id": item["id"],
+                "vector": item["vector"],
+                "text": item["text"],
+                "metadata": item["metadata"],
+                RESOURCE_ID_FIELD: resource_id,
+            }
+            for item in items
+        ]
+        collection.insert(entities)
+        collection.flush()
+
+    def search(
+        self, collection_name: str, vectors: List[List[float]], limit: int
+    ) -> Optional[SearchResult]:
+        if not vectors:
+            return None
+
+        mt_collection, resource_id = self._get_collection_and_resource_id(
+            collection_name
+        )
+        if not utility.has_collection(mt_collection):
+            return None
+
+        collection = Collection(mt_collection)
+        collection.load()
+
+        search_params = {"metric_type": MILVUS_METRIC_TYPE, "params": {}}
+        results = collection.search(
+            data=vectors,
+            anns_field="vector",
+            param=search_params,
+            limit=limit,
+            expr=f"{RESOURCE_ID_FIELD} == '{resource_id}'",
+            output_fields=["id", "text", "metadata"],
+        )
+
+        ids, documents, metadatas, distances = [], [], [], []
+        for hits in results:
+            batch_ids, batch_docs, batch_metadatas, batch_dists = [], [], [], []
+            for hit in hits:
+                batch_ids.append(hit.entity.get("id"))
+                batch_docs.append(hit.entity.get("text"))
+                batch_metadatas.append(hit.entity.get("metadata"))
+                batch_dists.append(hit.distance)
+            ids.append(batch_ids)
+            documents.append(batch_docs)
+            metadatas.append(batch_metadatas)
+            distances.append(batch_dists)
+
+        return SearchResult(
+            ids=ids, documents=documents, metadatas=metadatas, distances=distances
+        )
+
+    def delete(
+        self,
+        collection_name: str,
+        ids: Optional[List[str]] = None,
+        filter: Optional[Dict[str, Any]] = None,
+    ):
+        mt_collection, resource_id = self._get_collection_and_resource_id(
+            collection_name
+        )
+        if not utility.has_collection(mt_collection):
+            return
+
+        collection = Collection(mt_collection)
+        
+        # Build expression
+        expr = [f"{RESOURCE_ID_FIELD} == '{resource_id}'"]
+        if ids:
+            # Milvus expects a string list for 'in' operator
+            id_list_str = ", ".join([f"'{id_val}'" for id_val in ids])
+            expr.append(f"id in [{id_list_str}]")
+        
+        if filter:
+            for key, value in filter.items():
+                 expr.append(f"metadata['{key}'] == '{value}'")
+        
+        collection.delete(" and ".join(expr))
+
+    def reset(self):
+        for collection_name in self.shared_collections:
+            if utility.has_collection(collection_name):
+                utility.drop_collection(collection_name)
+
+    def delete_collection(self, collection_name: str):
+        mt_collection, resource_id = self._get_collection_and_resource_id(
+            collection_name
+        )
+        if not utility.has_collection(mt_collection):
+            return
+        
+        collection = Collection(mt_collection)
+        collection.delete(f"{RESOURCE_ID_FIELD} == '{resource_id}'")
+
+    def query(
+        self, collection_name: str, filter: Dict[str, Any], limit: Optional[int] = None
+    ) -> Optional[GetResult]:
+        mt_collection, resource_id = self._get_collection_and_resource_id(
+            collection_name
+        )
+        if not utility.has_collection(mt_collection):
+            return None
+
+        collection = Collection(mt_collection)
+        collection.load()
+
+        expr = [f"{RESOURCE_ID_FIELD} == '{resource_id}'"]
+        if filter:
+            for key, value in filter.items():
+                if isinstance(value, str):
+                    expr.append(f"metadata['{key}'] == '{value}'")
+                else:
+                    expr.append(f"metadata['{key}'] == {value}")
+
+        results = collection.query(
+            expr=" and ".join(expr),
+            output_fields=["id", "text", "metadata"],
+            limit=limit,
+        )
+
+        ids = [res["id"] for res in results]
+        documents = [res["text"] for res in results]
+        metadatas = [res["metadata"] for res in results]
+
+        return GetResult(ids=[ids], documents=[documents], metadatas=[metadatas])
+
+    def get(self, collection_name: str) -> Optional[GetResult]:
+        return self.query(collection_name, filter={}, limit=None)
+
+    def insert(self, collection_name: str, items: List[VectorItem]):
+        return self.upsert(collection_name, items)