# milvus_multitenancy.py
#
# Multi-tenant Milvus vector store backend: logical collections are stored in
# a small set of shared Milvus collections and told apart by a resource-id field.
  1. import logging
  2. from typing import Optional, Tuple, List, Dict, Any
  3. from open_webui.config import (
  4. MILVUS_URI,
  5. MILVUS_TOKEN,
  6. MILVUS_DB,
  7. MILVUS_COLLECTION_PREFIX,
  8. MILVUS_INDEX_TYPE,
  9. MILVUS_METRIC_TYPE,
  10. MILVUS_HNSW_M,
  11. MILVUS_HNSW_EFCONSTRUCTION,
  12. MILVUS_IVF_FLAT_NLIST,
  13. )
  14. from open_webui.env import SRC_LOG_LEVELS
  15. from open_webui.retrieval.vector.main import (
  16. GetResult,
  17. SearchResult,
  18. VectorDBBase,
  19. VectorItem,
  20. )
  21. from pymilvus import (
  22. connections,
  23. utility,
  24. Collection,
  25. CollectionSchema,
  26. FieldSchema,
  27. DataType,
  28. )
  29. log = logging.getLogger(__name__)
  30. log.setLevel(SRC_LOG_LEVELS["RAG"])
  31. RESOURCE_ID_FIELD = "resource_id"
  32. class MilvusClient(VectorDBBase):
  33. def __init__(self):
  34. # Milvus collection names can only contain numbers, letters, and underscores.
  35. self.collection_prefix = MILVUS_COLLECTION_PREFIX.replace("-", "_")
  36. connections.connect(
  37. alias="default",
  38. uri=MILVUS_URI,
  39. token=MILVUS_TOKEN,
  40. db_name=MILVUS_DB,
  41. )
  42. # Main collection types for multi-tenancy
  43. self.MEMORY_COLLECTION = f"{self.collection_prefix}_memories"
  44. self.KNOWLEDGE_COLLECTION = f"{self.collection_prefix}_knowledge"
  45. self.FILE_COLLECTION = f"{self.collection_prefix}_files"
  46. self.WEB_SEARCH_COLLECTION = f"{self.collection_prefix}_web_search"
  47. self.HASH_BASED_COLLECTION = f"{self.collection_prefix}_hash_based"
  48. self.shared_collections = [
  49. self.MEMORY_COLLECTION,
  50. self.KNOWLEDGE_COLLECTION,
  51. self.FILE_COLLECTION,
  52. self.WEB_SEARCH_COLLECTION,
  53. self.HASH_BASED_COLLECTION,
  54. ]
  55. def _get_collection_and_resource_id(self, collection_name: str) -> Tuple[str, str]:
  56. """
  57. Maps the traditional collection name to multi-tenant collection and resource ID.
  58. WARNING: This mapping relies on current Open WebUI naming conventions for
  59. collection names. If Open WebUI changes how it generates collection names
  60. (e.g., "user-memory-" prefix, "file-" prefix, web search patterns, or hash
  61. formats), this mapping will break and route data to incorrect collections.
  62. POTENTIALLY CAUSING HUGE DATA CORRUPTION, DATA CONSISTENCY ISSUES AND INCORRECT
  63. DATA MAPPING INSIDE THE DATABASE.
  64. """
  65. resource_id = collection_name
  66. if collection_name.startswith("user-memory-"):
  67. return self.MEMORY_COLLECTION, resource_id
  68. elif collection_name.startswith("file-"):
  69. return self.FILE_COLLECTION, resource_id
  70. elif collection_name.startswith("web-search-"):
  71. return self.WEB_SEARCH_COLLECTION, resource_id
  72. elif len(collection_name) == 63 and all(
  73. c in "0123456789abcdef" for c in collection_name
  74. ):
  75. return self.HASH_BASED_COLLECTION, resource_id
  76. else:
  77. return self.KNOWLEDGE_COLLECTION, resource_id
  78. def _create_shared_collection(self, mt_collection_name: str, dimension: int):
  79. fields = [
  80. FieldSchema(
  81. name="id",
  82. dtype=DataType.VARCHAR,
  83. is_primary=True,
  84. auto_id=False,
  85. max_length=36,
  86. ),
  87. FieldSchema(
  88. name="vector", dtype=DataType.FLOAT_VECTOR, dim=dimension
  89. ),
  90. FieldSchema(name="text", dtype=DataType.VARCHAR, max_length=65535),
  91. FieldSchema(name="metadata", dtype=DataType.JSON),
  92. FieldSchema(
  93. name=RESOURCE_ID_FIELD, dtype=DataType.VARCHAR, max_length=255
  94. ),
  95. ]
  96. schema = CollectionSchema(fields, "Shared collection for multi-tenancy")
  97. collection = Collection(mt_collection_name, schema)
  98. index_params = {
  99. "metric_type": MILVUS_METRIC_TYPE,
  100. "index_type": MILVUS_INDEX_TYPE,
  101. "params": {},
  102. }
  103. if MILVUS_INDEX_TYPE == "HNSW":
  104. index_params["params"] = {
  105. "M": MILVUS_HNSW_M,
  106. "efConstruction": MILVUS_HNSW_EFCONSTRUCTION,
  107. }
  108. elif MILVUS_INDEX_TYPE == "IVF_FLAT":
  109. index_params["params"] = {"nlist": MILVUS_IVF_FLAT_NLIST}
  110. collection.create_index("vector", index_params)
  111. collection.create_index(RESOURCE_ID_FIELD)
  112. log.info(f"Created shared collection: {mt_collection_name}")
  113. return collection
  114. def _ensure_collection(self, mt_collection_name: str, dimension: int):
  115. if not utility.has_collection(mt_collection_name):
  116. self._create_shared_collection(mt_collection_name, dimension)
  117. def has_collection(self, collection_name: str) -> bool:
  118. mt_collection, resource_id = self._get_collection_and_resource_id(
  119. collection_name
  120. )
  121. if not utility.has_collection(mt_collection):
  122. return False
  123. collection = Collection(mt_collection)
  124. collection.load()
  125. res = collection.query(
  126. expr=f"{RESOURCE_ID_FIELD} == '{resource_id}'", limit=1
  127. )
  128. return len(res) > 0
  129. def upsert(self, collection_name: str, items: List[VectorItem]):
  130. if not items:
  131. return
  132. mt_collection, resource_id = self._get_collection_and_resource_id(
  133. collection_name
  134. )
  135. dimension = len(items[0]["vector"])
  136. self._ensure_collection(mt_collection, dimension)
  137. collection = Collection(mt_collection)
  138. entities = [
  139. {
  140. "id": item["id"],
  141. "vector": item["vector"],
  142. "text": item["text"],
  143. "metadata": item["metadata"],
  144. RESOURCE_ID_FIELD: resource_id,
  145. }
  146. for item in items
  147. ]
  148. collection.insert(entities)
  149. collection.flush()
  150. def search(
  151. self, collection_name: str, vectors: List[List[float]], limit: int
  152. ) -> Optional[SearchResult]:
  153. if not vectors:
  154. return None
  155. mt_collection, resource_id = self._get_collection_and_resource_id(
  156. collection_name
  157. )
  158. if not utility.has_collection(mt_collection):
  159. return None
  160. collection = Collection(mt_collection)
  161. collection.load()
  162. search_params = {"metric_type": MILVUS_METRIC_TYPE, "params": {}}
  163. results = collection.search(
  164. data=vectors,
  165. anns_field="vector",
  166. param=search_params,
  167. limit=limit,
  168. expr=f"{RESOURCE_ID_FIELD} == '{resource_id}'",
  169. output_fields=["id", "text", "metadata"],
  170. )
  171. ids, documents, metadatas, distances = [], [], [], []
  172. for hits in results:
  173. batch_ids, batch_docs, batch_metadatas, batch_dists = [], [], [], []
  174. for hit in hits:
  175. batch_ids.append(hit.entity.get("id"))
  176. batch_docs.append(hit.entity.get("text"))
  177. batch_metadatas.append(hit.entity.get("metadata"))
  178. batch_dists.append(hit.distance)
  179. ids.append(batch_ids)
  180. documents.append(batch_docs)
  181. metadatas.append(batch_metadatas)
  182. distances.append(batch_dists)
  183. return SearchResult(
  184. ids=ids, documents=documents, metadatas=metadatas, distances=distances
  185. )
  186. def delete(
  187. self,
  188. collection_name: str,
  189. ids: Optional[List[str]] = None,
  190. filter: Optional[Dict[str, Any]] = None,
  191. ):
  192. mt_collection, resource_id = self._get_collection_and_resource_id(
  193. collection_name
  194. )
  195. if not utility.has_collection(mt_collection):
  196. return
  197. collection = Collection(mt_collection)
  198. # Build expression
  199. expr = [f"{RESOURCE_ID_FIELD} == '{resource_id}'"]
  200. if ids:
  201. # Milvus expects a string list for 'in' operator
  202. id_list_str = ", ".join([f"'{id_val}'" for id_val in ids])
  203. expr.append(f"id in [{id_list_str}]")
  204. if filter:
  205. for key, value in filter.items():
  206. expr.append(f"metadata['{key}'] == '{value}'")
  207. collection.delete(" and ".join(expr))
  208. def reset(self):
  209. for collection_name in self.shared_collections:
  210. if utility.has_collection(collection_name):
  211. utility.drop_collection(collection_name)
  212. def delete_collection(self, collection_name: str):
  213. mt_collection, resource_id = self._get_collection_and_resource_id(
  214. collection_name
  215. )
  216. if not utility.has_collection(mt_collection):
  217. return
  218. collection = Collection(mt_collection)
  219. collection.delete(f"{RESOURCE_ID_FIELD} == '{resource_id}'")
  220. def query(
  221. self, collection_name: str, filter: Dict[str, Any], limit: Optional[int] = None
  222. ) -> Optional[GetResult]:
  223. mt_collection, resource_id = self._get_collection_and_resource_id(
  224. collection_name
  225. )
  226. if not utility.has_collection(mt_collection):
  227. return None
  228. collection = Collection(mt_collection)
  229. collection.load()
  230. expr = [f"{RESOURCE_ID_FIELD} == '{resource_id}'"]
  231. if filter:
  232. for key, value in filter.items():
  233. if isinstance(value, str):
  234. expr.append(f"metadata['{key}'] == '{value}'")
  235. else:
  236. expr.append(f"metadata['{key}'] == {value}")
  237. results = collection.query(
  238. expr=" and ".join(expr),
  239. output_fields=["id", "text", "metadata"],
  240. limit=limit,
  241. )
  242. ids = [res["id"] for res in results]
  243. documents = [res["text"] for res in results]
  244. metadatas = [res["metadata"] for res in results]
  245. return GetResult(ids=[ids], documents=[documents], metadatas=[metadatas])
  246. def get(self, collection_name: str) -> Optional[GetResult]:
  247. return self.query(collection_name, filter={}, limit=None)
  248. def insert(self, collection_name: str, items: List[VectorItem]):
  249. return self.upsert(collection_name, items)