# milvus_multitenancy.py
  1. import logging
  2. from typing import Optional, Tuple, List, Dict, Any
  3. from open_webui.config import (
  4. MILVUS_URI,
  5. MILVUS_TOKEN,
  6. MILVUS_DB,
  7. MILVUS_COLLECTION_PREFIX,
  8. MILVUS_INDEX_TYPE,
  9. MILVUS_METRIC_TYPE,
  10. MILVUS_HNSW_M,
  11. MILVUS_HNSW_EFCONSTRUCTION,
  12. MILVUS_IVF_FLAT_NLIST,
  13. )
  14. from open_webui.env import SRC_LOG_LEVELS
  15. from open_webui.retrieval.vector.main import (
  16. GetResult,
  17. SearchResult,
  18. VectorDBBase,
  19. VectorItem,
  20. )
  21. from pymilvus import (
  22. connections,
  23. utility,
  24. Collection,
  25. CollectionSchema,
  26. FieldSchema,
  27. DataType,
  28. )
  29. log = logging.getLogger(__name__)
  30. log.setLevel(SRC_LOG_LEVELS["RAG"])
  31. RESOURCE_ID_FIELD = "resource_id"
  32. class MilvusClient(VectorDBBase):
  33. def __init__(self):
  34. # Milvus collection names can only contain numbers, letters, and underscores.
  35. self.collection_prefix = MILVUS_COLLECTION_PREFIX.replace("-", "_")
  36. connections.connect(
  37. alias="default",
  38. uri=MILVUS_URI,
  39. token=MILVUS_TOKEN,
  40. db_name=MILVUS_DB,
  41. )
  42. # Main collection types for multi-tenancy
  43. self.MEMORY_COLLECTION = f"{self.collection_prefix}_memories"
  44. self.KNOWLEDGE_COLLECTION = f"{self.collection_prefix}_knowledge"
  45. self.FILE_COLLECTION = f"{self.collection_prefix}_files"
  46. self.WEB_SEARCH_COLLECTION = f"{self.collection_prefix}_web_search"
  47. self.HASH_BASED_COLLECTION = f"{self.collection_prefix}_hash_based"
  48. self.shared_collections = [
  49. self.MEMORY_COLLECTION,
  50. self.KNOWLEDGE_COLLECTION,
  51. self.FILE_COLLECTION,
  52. self.WEB_SEARCH_COLLECTION,
  53. self.HASH_BASED_COLLECTION,
  54. ]
  55. def _get_collection_and_resource_id(self, collection_name: str) -> Tuple[str, str]:
  56. """
  57. Maps the traditional collection name to multi-tenant collection and resource ID.
  58. WARNING: This mapping relies on current Open WebUI naming conventions for
  59. collection names. If Open WebUI changes how it generates collection names
  60. (e.g., "user-memory-" prefix, "file-" prefix, web search patterns, or hash
  61. formats), this mapping will break and route data to incorrect collections.
  62. POTENTIALLY CAUSING HUGE DATA CORRUPTION, DATA CONSISTENCY ISSUES AND INCORRECT
  63. DATA MAPPING INSIDE THE DATABASE.
  64. """
  65. resource_id = collection_name
  66. if collection_name.startswith("user-memory-"):
  67. return self.MEMORY_COLLECTION, resource_id
  68. elif collection_name.startswith("file-"):
  69. return self.FILE_COLLECTION, resource_id
  70. elif collection_name.startswith("web-search-"):
  71. return self.WEB_SEARCH_COLLECTION, resource_id
  72. elif len(collection_name) == 63 and all(
  73. c in "0123456789abcdef" for c in collection_name
  74. ):
  75. return self.HASH_BASED_COLLECTION, resource_id
  76. else:
  77. return self.KNOWLEDGE_COLLECTION, resource_id
  78. def _create_shared_collection(self, mt_collection_name: str, dimension: int):
  79. fields = [
  80. FieldSchema(
  81. name="id",
  82. dtype=DataType.VARCHAR,
  83. is_primary=True,
  84. auto_id=False,
  85. max_length=36,
  86. ),
  87. FieldSchema(name="vector", dtype=DataType.FLOAT_VECTOR, dim=dimension),
  88. FieldSchema(name="text", dtype=DataType.VARCHAR, max_length=65535),
  89. FieldSchema(name="metadata", dtype=DataType.JSON),
  90. FieldSchema(name=RESOURCE_ID_FIELD, dtype=DataType.VARCHAR, max_length=255),
  91. ]
  92. schema = CollectionSchema(fields, "Shared collection for multi-tenancy")
  93. collection = Collection(mt_collection_name, schema)
  94. index_params = {
  95. "metric_type": MILVUS_METRIC_TYPE,
  96. "index_type": MILVUS_INDEX_TYPE,
  97. "params": {},
  98. }
  99. if MILVUS_INDEX_TYPE == "HNSW":
  100. index_params["params"] = {
  101. "M": MILVUS_HNSW_M,
  102. "efConstruction": MILVUS_HNSW_EFCONSTRUCTION,
  103. }
  104. elif MILVUS_INDEX_TYPE == "IVF_FLAT":
  105. index_params["params"] = {"nlist": MILVUS_IVF_FLAT_NLIST}
  106. collection.create_index("vector", index_params)
  107. collection.create_index(RESOURCE_ID_FIELD)
  108. log.info(f"Created shared collection: {mt_collection_name}")
  109. return collection
  110. def _ensure_collection(self, mt_collection_name: str, dimension: int):
  111. if not utility.has_collection(mt_collection_name):
  112. self._create_shared_collection(mt_collection_name, dimension)
  113. def has_collection(self, collection_name: str) -> bool:
  114. mt_collection, resource_id = self._get_collection_and_resource_id(
  115. collection_name
  116. )
  117. if not utility.has_collection(mt_collection):
  118. return False
  119. collection = Collection(mt_collection)
  120. collection.load()
  121. res = collection.query(expr=f"{RESOURCE_ID_FIELD} == '{resource_id}'", limit=1)
  122. return len(res) > 0
  123. def upsert(self, collection_name: str, items: List[VectorItem]):
  124. if not items:
  125. return
  126. mt_collection, resource_id = self._get_collection_and_resource_id(
  127. collection_name
  128. )
  129. dimension = len(items[0]["vector"])
  130. self._ensure_collection(mt_collection, dimension)
  131. collection = Collection(mt_collection)
  132. entities = [
  133. {
  134. "id": item["id"],
  135. "vector": item["vector"],
  136. "text": item["text"],
  137. "metadata": item["metadata"],
  138. RESOURCE_ID_FIELD: resource_id,
  139. }
  140. for item in items
  141. ]
  142. collection.insert(entities)
  143. collection.flush()
  144. def search(
  145. self, collection_name: str, vectors: List[List[float]], limit: int
  146. ) -> Optional[SearchResult]:
  147. if not vectors:
  148. return None
  149. mt_collection, resource_id = self._get_collection_and_resource_id(
  150. collection_name
  151. )
  152. if not utility.has_collection(mt_collection):
  153. return None
  154. collection = Collection(mt_collection)
  155. collection.load()
  156. search_params = {"metric_type": MILVUS_METRIC_TYPE, "params": {}}
  157. results = collection.search(
  158. data=vectors,
  159. anns_field="vector",
  160. param=search_params,
  161. limit=limit,
  162. expr=f"{RESOURCE_ID_FIELD} == '{resource_id}'",
  163. output_fields=["id", "text", "metadata"],
  164. )
  165. ids, documents, metadatas, distances = [], [], [], []
  166. for hits in results:
  167. batch_ids, batch_docs, batch_metadatas, batch_dists = [], [], [], []
  168. for hit in hits:
  169. batch_ids.append(hit.entity.get("id"))
  170. batch_docs.append(hit.entity.get("text"))
  171. batch_metadatas.append(hit.entity.get("metadata"))
  172. batch_dists.append(hit.distance)
  173. ids.append(batch_ids)
  174. documents.append(batch_docs)
  175. metadatas.append(batch_metadatas)
  176. distances.append(batch_dists)
  177. return SearchResult(
  178. ids=ids, documents=documents, metadatas=metadatas, distances=distances
  179. )
  180. def delete(
  181. self,
  182. collection_name: str,
  183. ids: Optional[List[str]] = None,
  184. filter: Optional[Dict[str, Any]] = None,
  185. ):
  186. mt_collection, resource_id = self._get_collection_and_resource_id(
  187. collection_name
  188. )
  189. if not utility.has_collection(mt_collection):
  190. return
  191. collection = Collection(mt_collection)
  192. # Build expression
  193. expr = [f"{RESOURCE_ID_FIELD} == '{resource_id}'"]
  194. if ids:
  195. # Milvus expects a string list for 'in' operator
  196. id_list_str = ", ".join([f"'{id_val}'" for id_val in ids])
  197. expr.append(f"id in [{id_list_str}]")
  198. if filter:
  199. for key, value in filter.items():
  200. expr.append(f"metadata['{key}'] == '{value}'")
  201. collection.delete(" and ".join(expr))
  202. def reset(self):
  203. for collection_name in self.shared_collections:
  204. if utility.has_collection(collection_name):
  205. utility.drop_collection(collection_name)
  206. def delete_collection(self, collection_name: str):
  207. mt_collection, resource_id = self._get_collection_and_resource_id(
  208. collection_name
  209. )
  210. if not utility.has_collection(mt_collection):
  211. return
  212. collection = Collection(mt_collection)
  213. collection.delete(f"{RESOURCE_ID_FIELD} == '{resource_id}'")
  214. def query(
  215. self, collection_name: str, filter: Dict[str, Any], limit: Optional[int] = None
  216. ) -> Optional[GetResult]:
  217. mt_collection, resource_id = self._get_collection_and_resource_id(
  218. collection_name
  219. )
  220. if not utility.has_collection(mt_collection):
  221. return None
  222. collection = Collection(mt_collection)
  223. collection.load()
  224. expr = [f"{RESOURCE_ID_FIELD} == '{resource_id}'"]
  225. if filter:
  226. for key, value in filter.items():
  227. if isinstance(value, str):
  228. expr.append(f"metadata['{key}'] == '{value}'")
  229. else:
  230. expr.append(f"metadata['{key}'] == {value}")
  231. results = collection.query(
  232. expr=" and ".join(expr),
  233. output_fields=["id", "text", "metadata"],
  234. limit=limit,
  235. )
  236. ids = [res["id"] for res in results]
  237. documents = [res["text"] for res in results]
  238. metadatas = [res["metadata"] for res in results]
  239. return GetResult(ids=[ids], documents=[documents], metadatas=[metadatas])
  240. def get(self, collection_name: str) -> Optional[GetResult]:
  241. return self.query(collection_name, filter={}, limit=None)
  242. def insert(self, collection_name: str, items: List[VectorItem]):
  243. return self.upsert(collection_name, items)