|
@@ -0,0 +1,288 @@
|
|
|
+import logging
|
|
|
+from typing import Optional, Tuple, List, Dict, Any
|
|
|
+
|
|
|
+from open_webui.config import (
|
|
|
+ MILVUS_URI,
|
|
|
+ MILVUS_TOKEN,
|
|
|
+ MILVUS_DB,
|
|
|
+ MILVUS_COLLECTION_PREFIX,
|
|
|
+ MILVUS_INDEX_TYPE,
|
|
|
+ MILVUS_METRIC_TYPE,
|
|
|
+ MILVUS_HNSW_M,
|
|
|
+ MILVUS_HNSW_EFCONSTRUCTION,
|
|
|
+ MILVUS_IVF_FLAT_NLIST,
|
|
|
+)
|
|
|
+from open_webui.env import SRC_LOG_LEVELS
|
|
|
+from open_webui.retrieval.vector.main import (
|
|
|
+ GetResult,
|
|
|
+ SearchResult,
|
|
|
+ VectorDBBase,
|
|
|
+ VectorItem,
|
|
|
+)
|
|
|
+from pymilvus import (
|
|
|
+ connections,
|
|
|
+ utility,
|
|
|
+ Collection,
|
|
|
+ CollectionSchema,
|
|
|
+ FieldSchema,
|
|
|
+ DataType,
|
|
|
+)
|
|
|
+
|
|
|
# Module-level logger; its level is taken from the "RAG" entry of SRC_LOG_LEVELS.
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])

# Name of the scalar field that tags every row of a shared collection with
# the logical (per-resource) collection name it belongs to.
RESOURCE_ID_FIELD = "resource_id"
|
|
|
+
|
|
|
+
|
|
|
class MilvusClient(VectorDBBase):
    """Multi-tenant Milvus backend.

    Instead of one Milvus collection per logical collection, rows are stored
    in a handful of shared collections (memories, knowledge, files, web
    search, hash based). Each row carries the logical collection name in the
    ``RESOURCE_ID_FIELD`` column, and every read/delete is scoped to it.
    """

    def __init__(self):
        """Connect to Milvus and compute the names of the shared collections."""
        # Milvus collection names can only contain numbers, letters, and underscores.
        self.collection_prefix = MILVUS_COLLECTION_PREFIX.replace("-", "_")
        connections.connect(
            alias="default",
            uri=MILVUS_URI,
            token=MILVUS_TOKEN,
            db_name=MILVUS_DB,
        )

        # Main collection types for multi-tenancy
        self.MEMORY_COLLECTION = f"{self.collection_prefix}_memories"
        self.KNOWLEDGE_COLLECTION = f"{self.collection_prefix}_knowledge"
        self.FILE_COLLECTION = f"{self.collection_prefix}_files"
        self.WEB_SEARCH_COLLECTION = f"{self.collection_prefix}_web_search"
        self.HASH_BASED_COLLECTION = f"{self.collection_prefix}_hash_based"
        self.shared_collections = [
            self.MEMORY_COLLECTION,
            self.KNOWLEDGE_COLLECTION,
            self.FILE_COLLECTION,
            self.WEB_SEARCH_COLLECTION,
            self.HASH_BASED_COLLECTION,
        ]

    def _get_collection_and_resource_id(self, collection_name: str) -> Tuple[str, str]:
        """
        Maps the traditional collection name to multi-tenant collection and resource ID.

        The resource ID is always the unmodified ``collection_name``; only the
        target shared collection depends on the name's shape.

        WARNING: This mapping relies on current Open WebUI naming conventions for
        collection names. If Open WebUI changes how it generates collection names
        (e.g., "user-memory-" prefix, "file-" prefix, web search patterns, or hash
        formats), this mapping will break and route data to incorrect collections.
        POTENTIALLY CAUSING HUGE DATA CORRUPTION, DATA CONSISTENCY ISSUES AND INCORRECT
        DATA MAPPING INSIDE THE DATABASE.
        """
        resource_id = collection_name

        if collection_name.startswith("user-memory-"):
            return self.MEMORY_COLLECTION, resource_id
        if collection_name.startswith("file-"):
            return self.FILE_COLLECTION, resource_id
        if collection_name.startswith("web-search-"):
            return self.WEB_SEARCH_COLLECTION, resource_id

        # NOTE(review): 63 lowercase hex characters - presumably a hex digest
        # truncated to 63 chars by the caller; confirm against the code that
        # generates these names before changing the length.
        if len(collection_name) == 63 and all(
            c in "0123456789abcdef" for c in collection_name
        ):
            return self.HASH_BASED_COLLECTION, resource_id

        return self.KNOWLEDGE_COLLECTION, resource_id

    @staticmethod
    def _quote(value: str) -> str:
        """Return ``value`` as a single-quoted expression string literal.

        Backslashes and single quotes are escaped so that the value cannot
        terminate the literal early and alter the surrounding expression.
        """
        escaped = value.replace("\\", "\\\\").replace("'", "\\'")
        return f"'{escaped}'"

    def _build_expr(
        self,
        resource_id: str,
        filter: Optional[Dict[str, Any]] = None,
        ids: Optional[List[str]] = None,
    ) -> str:
        """Build the boolean expression shared by query/search/delete.

        The expression always restricts rows to ``resource_id``. ``ids`` adds
        an ``id in [...]`` clause and each ``filter`` entry adds an equality
        test on the JSON ``metadata`` field. String values are quoted and
        escaped; any other value is rendered with ``str()`` unquoted.
        """
        clauses = [f"{RESOURCE_ID_FIELD} == {self._quote(resource_id)}"]

        if ids:
            id_list_str = ", ".join(self._quote(str(id_val)) for id_val in ids)
            clauses.append(f"id in [{id_list_str}]")

        if filter:
            for key, value in filter.items():
                literal = self._quote(value) if isinstance(value, str) else f"{value}"
                clauses.append(f"metadata[{self._quote(str(key))}] == {literal}")

        return " and ".join(clauses)

    def _create_shared_collection(self, mt_collection_name: str, dimension: int):
        """Create a shared collection with vector and resource-id indexes.

        Args:
            mt_collection_name: Name of the shared Milvus collection.
            dimension: Dimension of the ``vector`` field.

        Returns:
            The newly created ``Collection``.
        """
        fields = [
            FieldSchema(
                name="id",
                dtype=DataType.VARCHAR,
                is_primary=True,
                auto_id=False,
                max_length=36,
            ),
            FieldSchema(name="vector", dtype=DataType.FLOAT_VECTOR, dim=dimension),
            FieldSchema(name="text", dtype=DataType.VARCHAR, max_length=65535),
            FieldSchema(name="metadata", dtype=DataType.JSON),
            FieldSchema(name=RESOURCE_ID_FIELD, dtype=DataType.VARCHAR, max_length=255),
        ]
        schema = CollectionSchema(fields, "Shared collection for multi-tenancy")
        collection = Collection(mt_collection_name, schema)

        # Only HNSW and IVF_FLAT get tuning parameters; any other index type
        # is created with an empty parameter dict.
        if MILVUS_INDEX_TYPE == "HNSW":
            build_params = {
                "M": MILVUS_HNSW_M,
                "efConstruction": MILVUS_HNSW_EFCONSTRUCTION,
            }
        elif MILVUS_INDEX_TYPE == "IVF_FLAT":
            build_params = {"nlist": MILVUS_IVF_FLAT_NLIST}
        else:
            build_params = {}

        index_params = {
            "metric_type": MILVUS_METRIC_TYPE,
            "index_type": MILVUS_INDEX_TYPE,
            "params": build_params,
        }

        collection.create_index("vector", index_params)
        # Every read is filtered on the resource id, so index that field too.
        collection.create_index(RESOURCE_ID_FIELD)
        log.info("Created shared collection: %s", mt_collection_name)
        return collection

    def _ensure_collection(self, mt_collection_name: str, dimension: int):
        """Create the shared collection if it does not exist yet."""
        if not utility.has_collection(mt_collection_name):
            self._create_shared_collection(mt_collection_name, dimension)

    def _write(self, collection_name: str, items: List[VectorItem], overwrite: bool):
        """Write ``items`` into the shared collection for ``collection_name``.

        Args:
            collection_name: Logical collection name (stored as resource id).
            items: Items providing ``id``, ``vector``, ``text`` and ``metadata``.
            overwrite: When true, use Milvus upsert so rows with an existing
                primary key are replaced; otherwise use a plain insert.
        """
        if not items:
            return

        mt_collection, resource_id = self._get_collection_and_resource_id(
            collection_name
        )
        # The vector dimension of the first item defines the schema when the
        # shared collection has to be created.
        dimension = len(items[0]["vector"])
        self._ensure_collection(mt_collection, dimension)
        collection = Collection(mt_collection)

        entities = [
            {
                "id": item["id"],
                "vector": item["vector"],
                "text": item["text"],
                "metadata": item["metadata"],
                RESOURCE_ID_FIELD: resource_id,
            }
            for item in items
        ]
        if overwrite:
            collection.upsert(entities)
        else:
            collection.insert(entities)
        collection.flush()

    def has_collection(self, collection_name: str) -> bool:
        """Return True if at least one row exists for this logical collection."""
        mt_collection, resource_id = self._get_collection_and_resource_id(
            collection_name
        )
        if not utility.has_collection(mt_collection):
            return False

        collection = Collection(mt_collection)
        collection.load()
        res = collection.query(expr=self._build_expr(resource_id), limit=1)
        return len(res) > 0

    def upsert(self, collection_name: str, items: List[VectorItem]):
        """Insert ``items``, replacing rows that share an existing ``id``."""
        return self._write(collection_name, items, overwrite=True)

    def search(
        self, collection_name: str, vectors: List[List[float]], limit: int
    ) -> Optional[SearchResult]:
        """Run a vector search restricted to one logical collection.

        Args:
            collection_name: Logical collection name to search within.
            vectors: Query vectors; one result batch is produced per vector.
            limit: Maximum number of hits per query vector.

        Returns:
            A ``SearchResult`` whose fields are lists of per-query lists, or
            ``None`` when ``vectors`` is empty or the shared collection does
            not exist. Distances are passed through as reported by Milvus.
        """
        if not vectors:
            return None

        mt_collection, resource_id = self._get_collection_and_resource_id(
            collection_name
        )
        if not utility.has_collection(mt_collection):
            return None

        collection = Collection(mt_collection)
        collection.load()

        search_params = {"metric_type": MILVUS_METRIC_TYPE, "params": {}}
        results = collection.search(
            data=vectors,
            anns_field="vector",
            param=search_params,
            limit=limit,
            expr=self._build_expr(resource_id),
            output_fields=["id", "text", "metadata"],
        )

        ids, documents, metadatas, distances = [], [], [], []
        for hits in results:
            batch_ids, batch_docs, batch_metadatas, batch_dists = [], [], [], []
            for hit in hits:
                batch_ids.append(hit.entity.get("id"))
                batch_docs.append(hit.entity.get("text"))
                batch_metadatas.append(hit.entity.get("metadata"))
                batch_dists.append(hit.distance)
            ids.append(batch_ids)
            documents.append(batch_docs)
            metadatas.append(batch_metadatas)
            distances.append(batch_dists)

        return SearchResult(
            ids=ids, documents=documents, metadatas=metadatas, distances=distances
        )

    def delete(
        self,
        collection_name: str,
        ids: Optional[List[str]] = None,
        filter: Optional[Dict[str, Any]] = None,
    ):
        """Delete rows of one logical collection.

        Args:
            collection_name: Logical collection name whose rows are targeted.
            ids: Optional primary keys to restrict the deletion to.
            filter: Optional metadata equality filter, typed the same way as
                in ``query`` (strings quoted, other values unquoted).

        When neither ``ids`` nor ``filter`` is given, every row of the
        logical collection is deleted.
        """
        mt_collection, resource_id = self._get_collection_and_resource_id(
            collection_name
        )
        if not utility.has_collection(mt_collection):
            return

        collection = Collection(mt_collection)
        collection.delete(self._build_expr(resource_id, filter=filter, ids=ids))

    def reset(self):
        """Drop every shared collection that currently exists."""
        for collection_name in self.shared_collections:
            if utility.has_collection(collection_name):
                utility.drop_collection(collection_name)

    def delete_collection(self, collection_name: str):
        """Delete all rows belonging to one logical collection."""
        mt_collection, resource_id = self._get_collection_and_resource_id(
            collection_name
        )
        if not utility.has_collection(mt_collection):
            return

        collection = Collection(mt_collection)
        collection.delete(self._build_expr(resource_id))

    def query(
        self, collection_name: str, filter: Dict[str, Any], limit: Optional[int] = None
    ) -> Optional[GetResult]:
        """Fetch rows of one logical collection matching a metadata filter.

        Args:
            collection_name: Logical collection name to read from.
            filter: Metadata equality filter; an empty dict matches all rows.
            limit: Optional maximum number of rows to return.

        Returns:
            A ``GetResult`` with a single batch of ids, documents and
            metadatas, or ``None`` when the shared collection does not exist.
        """
        mt_collection, resource_id = self._get_collection_and_resource_id(
            collection_name
        )
        if not utility.has_collection(mt_collection):
            return None

        collection = Collection(mt_collection)
        collection.load()

        # Only forward ``limit`` when the caller actually set one.
        query_kwargs = {} if limit is None else {"limit": limit}
        results = collection.query(
            expr=self._build_expr(resource_id, filter=filter),
            output_fields=["id", "text", "metadata"],
            **query_kwargs,
        )

        ids = [res["id"] for res in results]
        documents = [res["text"] for res in results]
        metadatas = [res["metadata"] for res in results]

        return GetResult(ids=[ids], documents=[documents], metadatas=[metadatas])

    def get(self, collection_name: str) -> Optional[GetResult]:
        """Return every row of one logical collection."""
        return self.query(collection_name, filter={}, limit=None)

    def insert(self, collection_name: str, items: List[VectorItem]):
        """Insert ``items`` without replacing existing rows."""
        return self._write(collection_name, items, overwrite=False)
|