# milvus_multitenancy.py
  1. import logging
  2. from typing import Optional, Tuple, List, Dict, Any
  3. from open_webui.config import (
  4. MILVUS_URI,
  5. MILVUS_TOKEN,
  6. MILVUS_DB,
  7. MILVUS_COLLECTION_PREFIX,
  8. MILVUS_INDEX_TYPE,
  9. MILVUS_METRIC_TYPE,
  10. MILVUS_HNSW_M,
  11. MILVUS_HNSW_EFCONSTRUCTION,
  12. MILVUS_IVF_FLAT_NLIST,
  13. )
  14. from open_webui.env import SRC_LOG_LEVELS
  15. from open_webui.retrieval.vector.main import (
  16. GetResult,
  17. SearchResult,
  18. VectorDBBase,
  19. VectorItem,
  20. )
  21. from pymilvus import (
  22. connections,
  23. utility,
  24. Collection,
  25. CollectionSchema,
  26. FieldSchema,
  27. DataType,
  28. )
  # Name of the VARCHAR column that tags every row in a shared Milvus
  # collection with the logical (per-tenant) collection name it belongs to.
  RESOURCE_ID_FIELD = "resource_id"

  # Module logger; verbosity follows the level configured for the RAG subsystem.
  log = logging.getLogger(__name__)
  log.setLevel(SRC_LOG_LEVELS["RAG"])
  class MilvusClient(VectorDBBase):
      """Milvus vector store that multiplexes many logical collections onto a
      small, fixed set of shared Milvus collections.

      Every logical ``collection_name`` passed to the public methods is routed
      by ``_get_collection_and_resource_id`` to one of five shared physical
      collections. Each row stores the logical name in ``RESOURCE_ID_FIELD``,
      and every read, write and delete below is restricted to that value.
      """

      def __init__(self):
          """Open the default Milvus connection and derive the shared collection names."""
          # Milvus collection names can only contain numbers, letters, and underscores.
          self.collection_prefix = MILVUS_COLLECTION_PREFIX.replace("-", "_")
          connections.connect(
              alias="default",
              uri=MILVUS_URI,
              token=MILVUS_TOKEN,
              db_name=MILVUS_DB,
          )
          # Main collection types for multi-tenancy
          self.MEMORY_COLLECTION = f"{self.collection_prefix}_memories"
          self.KNOWLEDGE_COLLECTION = f"{self.collection_prefix}_knowledge"
          self.FILE_COLLECTION = f"{self.collection_prefix}_files"
          self.WEB_SEARCH_COLLECTION = f"{self.collection_prefix}_web_search"
          self.HASH_BASED_COLLECTION = f"{self.collection_prefix}_hash_based"
          self.shared_collections = [
              self.MEMORY_COLLECTION,
              self.KNOWLEDGE_COLLECTION,
              self.FILE_COLLECTION,
              self.WEB_SEARCH_COLLECTION,
              self.HASH_BASED_COLLECTION,
          ]

      # ------------------------------------------------------------------
      # Expression helpers
      # ------------------------------------------------------------------

      @staticmethod
      def _quote(value: Any) -> str:
          """Return ``str(value)`` as a single-quoted Milvus string literal.

          Backslashes and single quotes are escaped so that a value containing
          a quote cannot terminate the literal early and append extra clauses
          to the boolean expression.
          """
          escaped = str(value).replace("\\", "\\\\").replace("'", "\\'")
          return f"'{escaped}'"

      def _resource_expr(self, resource_id: str) -> str:
          """Return the expression that restricts rows to one logical collection."""
          return f"{RESOURCE_ID_FIELD} == {self._quote(resource_id)}"

      def _metadata_exprs(self, filters: Dict[str, Any]) -> List[str]:
          """Translate a ``{key: value}`` metadata filter into equality clauses.

          String values are emitted as quoted literals; any other value is
          interpolated as-is (e.g. ``5`` or ``True``) so numeric and boolean
          metadata is not compared against a string.
          """
          clauses = []
          for key, value in filters.items():
              rendered = self._quote(value) if isinstance(value, str) else value
              clauses.append(f"metadata[{self._quote(key)}] == {rendered}")
          return clauses

      # ------------------------------------------------------------------
      # Collection routing / lifecycle
      # ------------------------------------------------------------------

      def _get_collection_and_resource_id(self, collection_name: str) -> Tuple[str, str]:
          """
          Maps the traditional collection name to multi-tenant collection and resource ID.

          The logical name itself is always the resource ID. Only the physical
          collection varies, chosen by the first matching rule:

          * ``user-memory-*``               -> memories collection
          * ``file-*``                      -> files collection
          * ``web-search-*``                -> web-search collection
          * exactly 63 lowercase hex chars  -> hash-based collection
          * anything else                   -> knowledge collection

          NOTE(review): routing relies on the callers' naming conventions; if
          those change, rows silently land in a different shared collection.
          The length 63 presumably matches a SHA-256 hex digest truncated to
          63 characters by the caller. Confirm before changing it.
          """
          resource_id = collection_name
          if collection_name.startswith("user-memory-"):
              return self.MEMORY_COLLECTION, resource_id
          if collection_name.startswith("file-"):
              return self.FILE_COLLECTION, resource_id
          if collection_name.startswith("web-search-"):
              return self.WEB_SEARCH_COLLECTION, resource_id
          if len(collection_name) == 63 and all(
              c in "0123456789abcdef" for c in collection_name
          ):
              return self.HASH_BASED_COLLECTION, resource_id
          return self.KNOWLEDGE_COLLECTION, resource_id

      def _create_shared_collection(self, mt_collection_name: str, dimension: int):
          """Create a shared collection with the multi-tenant schema and its indexes.

          Args:
              mt_collection_name: Physical Milvus collection name.
              dimension: Length of the float vectors stored in it.

          Returns:
              The newly created ``Collection``.
          """
          fields = [
              FieldSchema(
                  name="id",
                  dtype=DataType.VARCHAR,
                  is_primary=True,
                  auto_id=False,
                  max_length=36,
              ),
              FieldSchema(name="vector", dtype=DataType.FLOAT_VECTOR, dim=dimension),
              FieldSchema(name="text", dtype=DataType.VARCHAR, max_length=65535),
              FieldSchema(name="metadata", dtype=DataType.JSON),
              FieldSchema(name=RESOURCE_ID_FIELD, dtype=DataType.VARCHAR, max_length=255),
          ]
          schema = CollectionSchema(fields, "Shared collection for multi-tenancy")
          collection = Collection(mt_collection_name, schema)

          index_params = {
              "metric_type": MILVUS_METRIC_TYPE,
              "index_type": MILVUS_INDEX_TYPE,
              "params": {},
          }
          # Only HNSW and IVF_FLAT get tuned build parameters; other index
          # types are created with an empty params dict.
          if MILVUS_INDEX_TYPE == "HNSW":
              index_params["params"] = {
                  "M": MILVUS_HNSW_M,
                  "efConstruction": MILVUS_HNSW_EFCONSTRUCTION,
              }
          elif MILVUS_INDEX_TYPE == "IVF_FLAT":
              index_params["params"] = {"nlist": MILVUS_IVF_FLAT_NLIST}
          collection.create_index("vector", index_params)
          # Scalar index on the tenant column; every operation filters on it.
          collection.create_index(RESOURCE_ID_FIELD)
          log.info(f"Created shared collection: {mt_collection_name}")
          return collection

      def _ensure_collection(self, mt_collection_name: str, dimension: int):
          """Create the shared collection if it does not exist yet."""
          if not utility.has_collection(mt_collection_name):
              self._create_shared_collection(mt_collection_name, dimension)

      def has_collection(self, collection_name: str) -> bool:
          """Return True if at least one row exists for ``collection_name``."""
          mt_collection, resource_id = self._get_collection_and_resource_id(
              collection_name
          )
          if not utility.has_collection(mt_collection):
              return False
          collection = Collection(mt_collection)
          collection.load()
          res = collection.query(expr=self._resource_expr(resource_id), limit=1)
          return len(res) > 0

      # ------------------------------------------------------------------
      # Writes
      # ------------------------------------------------------------------

      def upsert(self, collection_name: str, items: List[VectorItem]):
          """Insert or replace ``items`` under the logical ``collection_name``.

          The shared collection is created on first use, sized from the first
          item's vector. Each item must provide ``id``, ``vector``, ``text``
          and ``metadata``. An empty ``items`` list is a no-op.
          """
          if not items:
              return
          mt_collection, resource_id = self._get_collection_and_resource_id(
              collection_name
          )
          dimension = len(items[0]["vector"])
          self._ensure_collection(mt_collection, dimension)
          collection = Collection(mt_collection)
          entities = [
              {
                  "id": item["id"],
                  "vector": item["vector"],
                  "text": item["text"],
                  "metadata": item["metadata"],
                  RESOURCE_ID_FIELD: resource_id,
              }
              for item in items
          ]
          # Milvus does not enforce primary-key uniqueness on insert(); use
          # upsert() so re-sending an existing id replaces the row instead of
          # creating a duplicate.
          collection.upsert(entities)
          collection.flush()

      def insert(self, collection_name: str, items: List[VectorItem]):
          """Alias for ``upsert``."""
          return self.upsert(collection_name, items)

      def delete(
          self,
          collection_name: str,
          ids: Optional[List[str]] = None,
          filter: Optional[Dict[str, Any]] = None,
      ):
          """Delete rows of ``collection_name`` matching ``ids`` and/or ``filter``.

          Conditions are AND-ed with the tenant restriction. If neither ``ids``
          nor ``filter`` is given, every row of the logical collection is
          deleted. An explicit empty ``ids`` list deletes nothing.
          """
          if ids is not None and not ids:
              # "id in []" selects nothing; without this guard the expression
              # would reduce to the tenant clause alone and wipe the resource.
              return
          mt_collection, resource_id = self._get_collection_and_resource_id(
              collection_name
          )
          if not utility.has_collection(mt_collection):
              return
          collection = Collection(mt_collection)
          # Deleting by a non-primary-key expression requires a loaded collection.
          collection.load()
          expr = [self._resource_expr(resource_id)]
          if ids:
              id_list = ", ".join(self._quote(id_val) for id_val in ids)
              expr.append(f"id in [{id_list}]")
          if filter:
              expr.extend(self._metadata_exprs(filter))
          collection.delete(" and ".join(expr))

      def delete_collection(self, collection_name: str):
          """Delete every row of the logical collection; the shared collection stays."""
          mt_collection, resource_id = self._get_collection_and_resource_id(
              collection_name
          )
          if not utility.has_collection(mt_collection):
              return
          collection = Collection(mt_collection)
          # Deleting by a non-primary-key expression requires a loaded collection.
          collection.load()
          collection.delete(self._resource_expr(resource_id))

      def reset(self):
          """Drop all shared collections, removing data for every tenant."""
          for collection_name in self.shared_collections:
              if utility.has_collection(collection_name):
                  utility.drop_collection(collection_name)

      # ------------------------------------------------------------------
      # Reads
      # ------------------------------------------------------------------

      def search(
          self, collection_name: str, vectors: List[List[float]], limit: int
      ) -> Optional[SearchResult]:
          """Run a vector search per query vector within ``collection_name``.

          Returns None when no query vectors are given or the shared
          collection does not exist. Otherwise returns one list of
          ids/documents/metadatas/distances per query vector. Distances are
          passed through exactly as Milvus reports them.
          """
          if not vectors:
              return None
          mt_collection, resource_id = self._get_collection_and_resource_id(
              collection_name
          )
          if not utility.has_collection(mt_collection):
              return None
          collection = Collection(mt_collection)
          collection.load()
          search_params = {"metric_type": MILVUS_METRIC_TYPE, "params": {}}
          results = collection.search(
              data=vectors,
              anns_field="vector",
              param=search_params,
              limit=limit,
              expr=self._resource_expr(resource_id),
              output_fields=["id", "text", "metadata"],
          )
          ids, documents, metadatas, distances = [], [], [], []
          for hits in results:
              ids.append([hit.entity.get("id") for hit in hits])
              documents.append([hit.entity.get("text") for hit in hits])
              metadatas.append([hit.entity.get("metadata") for hit in hits])
              distances.append([hit.distance for hit in hits])
          return SearchResult(
              ids=ids, documents=documents, metadatas=metadatas, distances=distances
          )

      def query(
          self, collection_name: str, filter: Dict[str, Any], limit: Optional[int] = None
      ) -> Optional[GetResult]:
          """Return rows of ``collection_name`` whose metadata equals ``filter``.

          Returns None if the shared collection does not exist. Results are
          wrapped in single-element outer lists to match ``GetResult``.
          """
          mt_collection, resource_id = self._get_collection_and_resource_id(
              collection_name
          )
          if not utility.has_collection(mt_collection):
              return None
          collection = Collection(mt_collection)
          collection.load()
          expr = [self._resource_expr(resource_id)]
          if filter:
              expr.extend(self._metadata_exprs(filter))
          query_kwargs = {
              "expr": " and ".join(expr),
              "output_fields": ["id", "text", "metadata"],
          }
          # Forward the limit only when one was requested.
          if limit is not None:
              query_kwargs["limit"] = limit
          results = collection.query(**query_kwargs)
          ids = [res["id"] for res in results]
          documents = [res["text"] for res in results]
          metadatas = [res["metadata"] for res in results]
          return GetResult(ids=[ids], documents=[documents], metadatas=[metadatas])

      def get(self, collection_name: str) -> Optional[GetResult]:
          """Return every row of ``collection_name`` (no metadata filter, no limit)."""
          return self.query(collection_name, filter={}, limit=None)