import logging
from typing import Optional, Tuple, List, Dict, Any

from open_webui.config import (
    MILVUS_URI,
    MILVUS_TOKEN,
    MILVUS_DB,
    MILVUS_COLLECTION_PREFIX,
    MILVUS_INDEX_TYPE,
    MILVUS_METRIC_TYPE,
    MILVUS_HNSW_M,
    MILVUS_HNSW_EFCONSTRUCTION,
    MILVUS_IVF_FLAT_NLIST,
)
from open_webui.env import SRC_LOG_LEVELS
from open_webui.retrieval.vector.main import (
    GetResult,
    SearchResult,
    VectorDBBase,
    VectorItem,
)
from pymilvus import (
    connections,
    utility,
    Collection,
    CollectionSchema,
    FieldSchema,
    DataType,
)

log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])

# Scalar field that stores the caller-facing ("traditional") collection name
# of every entity inside a shared collection.
RESOURCE_ID_FIELD = "resource_id"

# Characters accepted by the hash-based collection-name check.
_HEX_DIGITS = frozenset("0123456789abcdef")


def _quote(value: Any) -> str:
    """Return *value* as a single-quoted Milvus string literal.

    Backslashes and single quotes are escaped so that a value containing
    them can neither break the filter expression nor inject extra clauses.
    For values without those characters the result is simply ``'value'``.
    """
    escaped = str(value).replace("\\", "\\\\").replace("'", "\\'")
    return f"'{escaped}'"


def _literal(value: Any) -> str:
    """Render a metadata filter value as a Milvus expression literal."""
    if isinstance(value, str):
        return _quote(value)
    # bool must be tested before falling through: render it in lowercase
    # rather than relying on Python's ``True``/``False`` spelling.
    if isinstance(value, bool):
        return "true" if value else "false"
    return f"{value}"


def _build_expr(
    resource_id: str,
    ids: Optional[List[str]] = None,
    filter: Optional[Dict[str, Any]] = None,
) -> str:
    """Build the boolean expression selecting entities of one resource.

    The expression always restricts to ``resource_id``; ``ids`` adds a
    primary-key ``in`` clause and each ``filter`` item adds an equality test
    on the corresponding key of the JSON ``metadata`` field. All clauses are
    joined with ``and``.
    """
    clauses = [f"{RESOURCE_ID_FIELD} == {_quote(resource_id)}"]
    if ids:
        id_list = ", ".join(_quote(id_val) for id_val in ids)
        clauses.append(f"id in [{id_list}]")
    if filter:
        clauses.extend(
            f"metadata[{_quote(key)}] == {_literal(value)}"
            for key, value in filter.items()
        )
    return " and ".join(clauses)


class MilvusClient(VectorDBBase):
    """Multi-tenant Milvus backend.

    Instead of one Milvus collection per logical collection, entities are
    stored in a small, fixed set of shared collections and tagged with the
    logical collection name in ``RESOURCE_ID_FIELD``. Every read and delete
    is scoped by a filter on that field.
    """

    def __init__(self):
        """Connect to Milvus and compute the shared collection names."""
        # Milvus collection names can only contain numbers, letters, and underscores.
        self.collection_prefix = MILVUS_COLLECTION_PREFIX.replace("-", "_")
        connections.connect(
            alias="default",
            uri=MILVUS_URI,
            token=MILVUS_TOKEN,
            db_name=MILVUS_DB,
        )

        # Main collection types for multi-tenancy
        self.MEMORY_COLLECTION = f"{self.collection_prefix}_memories"
        self.KNOWLEDGE_COLLECTION = f"{self.collection_prefix}_knowledge"
        self.FILE_COLLECTION = f"{self.collection_prefix}_files"
        self.WEB_SEARCH_COLLECTION = f"{self.collection_prefix}_web_search"
        self.HASH_BASED_COLLECTION = f"{self.collection_prefix}_hash_based"
        self.shared_collections = [
            self.MEMORY_COLLECTION,
            self.KNOWLEDGE_COLLECTION,
            self.FILE_COLLECTION,
            self.WEB_SEARCH_COLLECTION,
            self.HASH_BASED_COLLECTION,
        ]

    def _get_collection_and_resource_id(self, collection_name: str) -> Tuple[str, str]:
        """
        Maps the traditional collection name to multi-tenant collection and resource ID.

        The resource ID is always the unmodified ``collection_name``; only the
        shared collection is chosen from its shape.
        """
        resource_id = collection_name
        if collection_name.startswith("user-memory-"):
            return self.MEMORY_COLLECTION, resource_id
        if collection_name.startswith("file-"):
            return self.FILE_COLLECTION, resource_id
        if collection_name.startswith("web-search-"):
            return self.WEB_SEARCH_COLLECTION, resource_id
        # NOTE(review): 63 lowercase hex chars — presumably a truncated
        # SHA-256 hex digest produced by the caller; confirm before changing.
        if len(collection_name) == 63 and all(
            c in _HEX_DIGITS for c in collection_name
        ):
            return self.HASH_BASED_COLLECTION, resource_id
        return self.KNOWLEDGE_COLLECTION, resource_id

    def _create_shared_collection(self, mt_collection_name: str, dimension: int):
        """Create a shared collection with its vector and resource-id indexes.

        Args:
            mt_collection_name: Name of the shared Milvus collection.
            dimension: Dimensionality of the ``vector`` field.

        Returns:
            The newly created ``Collection``.
        """
        fields = [
            FieldSchema(
                name="id",
                dtype=DataType.VARCHAR,
                is_primary=True,
                auto_id=False,
                max_length=36,
            ),
            FieldSchema(name="vector", dtype=DataType.FLOAT_VECTOR, dim=dimension),
            FieldSchema(name="text", dtype=DataType.VARCHAR, max_length=65535),
            FieldSchema(name="metadata", dtype=DataType.JSON),
            FieldSchema(
                name=RESOURCE_ID_FIELD, dtype=DataType.VARCHAR, max_length=255
            ),
        ]
        schema = CollectionSchema(fields, "Shared collection for multi-tenancy")
        collection = Collection(mt_collection_name, schema)

        # Index-type specific build parameters; other index types get none.
        if MILVUS_INDEX_TYPE == "HNSW":
            build_params = {
                "M": MILVUS_HNSW_M,
                "efConstruction": MILVUS_HNSW_EFCONSTRUCTION,
            }
        elif MILVUS_INDEX_TYPE == "IVF_FLAT":
            build_params = {"nlist": MILVUS_IVF_FLAT_NLIST}
        else:
            build_params = {}
        index_params = {
            "metric_type": MILVUS_METRIC_TYPE,
            "index_type": MILVUS_INDEX_TYPE,
            "params": build_params,
        }

        collection.create_index("vector", index_params)
        # Every query filters on the resource id, so index that field as well.
        collection.create_index(RESOURCE_ID_FIELD)
        log.info("Created shared collection: %s", mt_collection_name)
        return collection

    def _ensure_collection(self, mt_collection_name: str, dimension: int):
        """Create the shared collection if it does not exist yet."""
        if not utility.has_collection(mt_collection_name):
            self._create_shared_collection(mt_collection_name, dimension)

    def _write(
        self,
        collection_name: str,
        items: List[VectorItem],
        *,
        replace_existing: bool,
    ) -> None:
        """Store ``items`` under ``collection_name`` and flush.

        Args:
            collection_name: Traditional (logical) collection name.
            items: Items to store; nothing happens when empty.
            replace_existing: When true, use Milvus ``upsert`` so entities
                whose primary key already exists are replaced; otherwise use
                a plain ``insert``.
        """
        if not items:
            return
        mt_collection, resource_id = self._get_collection_and_resource_id(
            collection_name
        )
        # The vector dimension of a new shared collection is taken from the
        # first item.
        self._ensure_collection(mt_collection, len(items[0]["vector"]))
        collection = Collection(mt_collection)
        entities = [
            {
                "id": item["id"],
                "vector": item["vector"],
                "text": item["text"],
                "metadata": item["metadata"],
                RESOURCE_ID_FIELD: resource_id,
            }
            for item in items
        ]
        if replace_existing:
            # NOTE: "id" is the primary key of the whole shared collection,
            # so replacement is keyed on id alone, not on (resource, id).
            collection.upsert(entities)
        else:
            collection.insert(entities)
        collection.flush()

    def has_collection(self, collection_name: str) -> bool:
        """Return whether at least one entity exists for ``collection_name``."""
        mt_collection, resource_id = self._get_collection_and_resource_id(
            collection_name
        )
        if not utility.has_collection(mt_collection):
            return False
        collection = Collection(mt_collection)
        collection.load()
        res = collection.query(expr=_build_expr(resource_id), limit=1)
        return len(res) > 0

    def upsert(self, collection_name: str, items: List[VectorItem]):
        """Insert ``items``, replacing entities whose id already exists.

        Uses Milvus ``upsert`` rather than ``insert`` so that writing the same
        item twice does not leave duplicate entities behind.
        """
        self._write(collection_name, items, replace_existing=True)

    def search(
        self, collection_name: str, vectors: List[List[float]], limit: int
    ) -> Optional[SearchResult]:
        """Run a vector search restricted to ``collection_name``.

        Args:
            collection_name: Traditional (logical) collection name.
            vectors: Query vectors; one result batch is produced per vector.
            limit: Maximum number of hits per query vector.

        Returns:
            A ``SearchResult`` with one inner list per query vector, or
            ``None`` when ``vectors`` is empty or the shared collection does
            not exist.
        """
        if not vectors:
            return None
        mt_collection, resource_id = self._get_collection_and_resource_id(
            collection_name
        )
        if not utility.has_collection(mt_collection):
            return None
        collection = Collection(mt_collection)
        collection.load()

        search_params = {"metric_type": MILVUS_METRIC_TYPE, "params": {}}
        results = collection.search(
            data=vectors,
            anns_field="vector",
            param=search_params,
            limit=limit,
            expr=_build_expr(resource_id),
            output_fields=["id", "text", "metadata"],
        )

        ids, documents, metadatas, distances = [], [], [], []
        for hits in results:
            ids.append([hit.entity.get("id") for hit in hits])
            documents.append([hit.entity.get("text") for hit in hits])
            metadatas.append([hit.entity.get("metadata") for hit in hits])
            distances.append([hit.distance for hit in hits])

        return SearchResult(
            ids=ids, documents=documents, metadatas=metadatas, distances=distances
        )

    def delete(
        self,
        collection_name: str,
        ids: Optional[List[str]] = None,
        filter: Optional[Dict[str, Any]] = None,
    ):
        """Delete entities of ``collection_name`` matching ``ids`` and ``filter``.

        ``ids`` and ``filter`` are combined with ``and``. When both are empty
        or ``None`` the expression contains only the resource clause, so every
        entity of the resource is deleted. Does nothing when the shared
        collection does not exist.
        """
        mt_collection, resource_id = self._get_collection_and_resource_id(
            collection_name
        )
        if not utility.has_collection(mt_collection):
            return
        collection = Collection(mt_collection)
        collection.delete(_build_expr(resource_id, ids=ids, filter=filter))

    def reset(self):
        """Drop every shared collection that currently exists."""
        for collection_name in self.shared_collections:
            if utility.has_collection(collection_name):
                utility.drop_collection(collection_name)

    def delete_collection(self, collection_name: str):
        """Delete all entities of ``collection_name``.

        The shared Milvus collection itself is kept; only the entities tagged
        with this resource ID are removed.
        """
        mt_collection, resource_id = self._get_collection_and_resource_id(
            collection_name
        )
        if not utility.has_collection(mt_collection):
            return
        collection = Collection(mt_collection)
        collection.delete(_build_expr(resource_id))

    def query(
        self, collection_name: str, filter: Dict[str, Any], limit: Optional[int] = None
    ) -> Optional[GetResult]:
        """Fetch entities of ``collection_name`` whose metadata matches ``filter``.

        Args:
            collection_name: Traditional (logical) collection name.
            filter: Metadata key/value pairs that must all match; may be empty.
            limit: Passed through to Milvus as the query limit.

        Returns:
            A ``GetResult`` holding a single batch, or ``None`` when the
            shared collection does not exist.
        """
        mt_collection, resource_id = self._get_collection_and_resource_id(
            collection_name
        )
        if not utility.has_collection(mt_collection):
            return None
        collection = Collection(mt_collection)
        collection.load()

        results = collection.query(
            expr=_build_expr(resource_id, filter=filter),
            output_fields=["id", "text", "metadata"],
            limit=limit,
        )

        ids = [res["id"] for res in results]
        documents = [res["text"] for res in results]
        metadatas = [res["metadata"] for res in results]

        return GetResult(ids=[ids], documents=[documents], metadatas=[metadatas])

    def get(self, collection_name: str) -> Optional[GetResult]:
        """Fetch all entities of ``collection_name`` (unfiltered ``query``)."""
        return self.query(collection_name, filter={}, limit=None)

    def insert(self, collection_name: str, items: List[VectorItem]):
        """Insert ``items`` without replacing entities that share an id."""
        return self._write(collection_name, items, replace_existing=False)