milvus.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383
  1. from pymilvus import MilvusClient as Client
  2. from pymilvus import FieldSchema, DataType
  3. from pymilvus import connections, Collection
  4. import json
  5. import logging
  6. from typing import Optional
  7. from open_webui.retrieval.vector.utils import stringify_metadata
  8. from open_webui.retrieval.vector.main import (
  9. VectorDBBase,
  10. VectorItem,
  11. SearchResult,
  12. GetResult,
  13. )
  14. from open_webui.config import (
  15. MILVUS_URI,
  16. MILVUS_DB,
  17. MILVUS_TOKEN,
  18. MILVUS_INDEX_TYPE,
  19. MILVUS_METRIC_TYPE,
  20. MILVUS_HNSW_M,
  21. MILVUS_HNSW_EFCONSTRUCTION,
  22. MILVUS_IVF_FLAT_NLIST,
  23. )
  24. from open_webui.env import SRC_LOG_LEVELS
  # Module logger; its level is taken from the "RAG" entry of SRC_LOG_LEVELS.
  log = logging.getLogger(name=__name__)
  log.setLevel(level=SRC_LOG_LEVELS["RAG"])
  class MilvusClient(VectorDBBase):
      """Open WebUI vector-store adapter backed by Milvus.

      Each logical collection name maps to a physical Milvus collection named
      ``"<collection_prefix>_<name>"``, with ``-`` replaced by ``_``. Rows hold
      an ``id`` (VARCHAR primary key), a float ``vector``, a JSON ``data`` field
      of the form ``{"text": ...}`` and a JSON ``metadata`` field.
      """

      def __init__(self):
          self.collection_prefix = "open_webui"
          if MILVUS_TOKEN is None:
              self.client = Client(uri=MILVUS_URI, db_name=MILVUS_DB)
          else:
              self.client = Client(
                  uri=MILVUS_URI, db_name=MILVUS_DB, token=MILVUS_TOKEN
              )

      # ------------------------------------------------------------------
      # Private helpers
      # ------------------------------------------------------------------

      def _full_name(self, collection_name: str) -> str:
          """Return the physical Milvus collection name for *collection_name*.

          Dashes are replaced by underscores (presumably because Milvus rejects
          ``-`` in collection names) and the prefix is prepended. The
          replacement is idempotent, so already-sanitized names are safe.
          """
          return f"{self.collection_prefix}_{collection_name.replace('-', '_')}"

      @staticmethod
      def _build_filter_string(conditions: dict) -> str:
          """Turn ``{key: value}`` into a Milvus boolean expression.

          Each pair becomes ``metadata["key"] == <JSON literal of value>``; pairs
          are joined with ``&&``. An empty dict yields ``""``.
          """
          return " && ".join(
              f'metadata["{key}"] == {json.dumps(value)}'
              for key, value in conditions.items()
          )

      @staticmethod
      def _to_rows(items: list[VectorItem]) -> list[dict]:
          """Convert VectorItems into row dicts matching the collection schema."""
          return [
              {
                  "id": item["id"],
                  "vector": item["vector"],
                  "data": {"text": item["text"]},
                  "metadata": stringify_metadata(item["metadata"]),
              }
              for item in items
          ]

      def _ensure_collection(
          self, collection_name: str, items: list[VectorItem], operation: str = ""
      ) -> None:
          """Create the collection if it is missing, sized from ``items[0]``.

          Args:
              collection_name: Logical collection name.
              items: Items about to be written; the first item's vector length
                  becomes the collection dimension.
              operation: Text spliced into log/error messages (e.g. " for upsert").

          Raises:
              ValueError: If the collection is missing and ``items`` is empty, so
                  the vector dimension cannot be determined.
          """
          full_name = self._full_name(collection_name)
          if self.client.has_collection(collection_name=full_name):
              return
          log.info(f"Collection {full_name} does not exist{operation}. Creating now.")
          if not items:
              log.error(
                  f"Cannot create collection {full_name}{operation} without items to determine dimension."
              )
              raise ValueError(
                  f"Cannot create Milvus collection{operation} without items to determine vector dimension."
              )
          self._create_collection(
              collection_name=collection_name, dimension=len(items[0]["vector"])
          )

      def _result_to_get_result(self, result) -> GetResult:
          """Convert a list of row lists (as returned by a query) into a GetResult."""
          ids = []
          documents = []
          metadatas = []
          for match in result:
              _ids = []
              _documents = []
              _metadatas = []
              for item in match:
                  _ids.append(item.get("id"))
                  _documents.append(item.get("data", {}).get("text"))
                  _metadatas.append(item.get("metadata"))
              ids.append(_ids)
              documents.append(_documents)
              metadatas.append(_metadatas)
          return GetResult(ids=ids, documents=documents, metadatas=metadatas)

      def _result_to_search_result(self, result) -> SearchResult:
          """Convert Milvus search hits (one list per query vector) into a SearchResult."""
          ids = []
          distances = []
          documents = []
          metadatas = []
          for match in result:
              _ids = []
              _distances = []
              _documents = []
              _metadatas = []
              for item in match:
                  _ids.append(item.get("id"))
                  # Normalize the Milvus score from [-1, 1] to [0, 1]
                  # (https://milvus.io/docs/de/metric.md).
                  # NOTE(review): this assumes a similarity metric such as
                  # COSINE/IP; MILVUS_METRIC_TYPE is configurable, and for L2 this
                  # mapping is not a [0, 1] normalization — confirm.
                  _distances.append((item.get("distance") + 1.0) / 2.0)
                  entity = item.get("entity", {})
                  _documents.append(entity.get("data", {}).get("text"))
                  _metadatas.append(entity.get("metadata"))
              ids.append(_ids)
              distances.append(_distances)
              documents.append(_documents)
              metadatas.append(_metadatas)
          return SearchResult(
              ids=ids, distances=distances, documents=documents, metadatas=metadatas
          )

      def _create_collection(self, collection_name: str, dimension: int):
          """Create a collection with the standard schema and a vector index.

          Args:
              collection_name: Logical collection name (the prefix is added here).
              dimension: Length of every vector stored in the collection.

          The index type and metric come from MILVUS_INDEX_TYPE and
          MILVUS_METRIC_TYPE; explicit build parameters are supplied only for
          HNSW and IVF_FLAT.
          """
          full_name = self._full_name(collection_name)
          schema = self.client.create_schema(
              auto_id=False,
              enable_dynamic_field=True,
          )
          schema.add_field(
              field_name="id",
              datatype=DataType.VARCHAR,
              is_primary=True,
              max_length=65535,
          )
          schema.add_field(
              field_name="vector",
              datatype=DataType.FLOAT_VECTOR,
              dim=dimension,
              description="vector",
          )
          schema.add_field(field_name="data", datatype=DataType.JSON, description="data")
          schema.add_field(
              field_name="metadata", datatype=DataType.JSON, description="metadata"
          )

          index_params = self.client.prepare_index_params()
          index_type = MILVUS_INDEX_TYPE.upper()
          metric_type = MILVUS_METRIC_TYPE.upper()
          log.info(f"Using Milvus index type: {index_type}, metric type: {metric_type}")

          index_creation_params = {}
          if index_type == "HNSW":
              index_creation_params = {
                  "M": MILVUS_HNSW_M,
                  "efConstruction": MILVUS_HNSW_EFCONSTRUCTION,
              }
              log.info(f"HNSW params: {index_creation_params}")
          elif index_type == "IVF_FLAT":
              index_creation_params = {"nlist": MILVUS_IVF_FLAT_NLIST}
              log.info(f"IVF_FLAT params: {index_creation_params}")
          elif index_type in ["FLAT", "AUTOINDEX"]:
              log.info(f"Using {index_type} index with no specific build-time params.")
          else:
              # Unknown types are passed through unchanged; if Milvus rejects the
              # type, the MILVUS_INDEX_TYPE setting needs to be corrected.
              log.warning(
                  f"Unsupported MILVUS_INDEX_TYPE: '{index_type}'. "
                  f"Supported types: HNSW, IVF_FLAT, FLAT, AUTOINDEX. "
                  f"Milvus will use its default for the collection if this type is not directly supported for index creation."
              )

          index_params.add_index(
              field_name="vector",
              index_type=index_type,
              metric_type=metric_type,
              params=index_creation_params,
          )
          self.client.create_collection(
              collection_name=full_name,
              schema=schema,
              index_params=index_params,
          )
          log.info(
              f"Successfully created collection '{full_name}' with index type '{index_type}' and metric '{metric_type}'."
          )

      # ------------------------------------------------------------------
      # Public API
      # ------------------------------------------------------------------

      def has_collection(self, collection_name: str) -> bool:
          """Return whether the prefixed, sanitized collection exists."""
          return self.client.has_collection(
              collection_name=self._full_name(collection_name)
          )

      def delete_collection(self, collection_name: str):
          """Drop the prefixed, sanitized collection and return the client's result."""
          return self.client.drop_collection(
              collection_name=self._full_name(collection_name)
          )

      def search(
          self, collection_name: str, vectors: list[list[float | int]], limit: int
      ) -> Optional[SearchResult]:
          """Return up to *limit* nearest neighbours for each query vector.

          Search-time index parameters (e.g. ``nprobe`` for IVF_FLAT) are not
          configurable yet, so Milvus defaults apply.
          """
          result = self.client.search(
              collection_name=self._full_name(collection_name),
              data=vectors,
              limit=limit,
              output_fields=["data", "metadata"],
          )
          return self._result_to_search_result(result)

      def query(self, collection_name: str, filter: dict, limit: Optional[int] = None):
          """Return rows whose metadata matches every ``key == value`` in *filter*.

          Args:
              collection_name: Logical collection name.
              filter: Metadata equality conditions, AND-ed together; ``{}``
                  matches every row.
              limit: Maximum number of rows to return; ``None`` means no limit.

          Returns:
              A GetResult holding a single inner list, or ``None`` when the
              collection does not exist or the query fails (the error is logged).
          """
          # This path uses the ORM ``Collection`` API, which is bound through
          # the module-level ``connections`` registry, not ``self.client``.
          if MILVUS_TOKEN is None:
              connections.connect(uri=MILVUS_URI, db_name=MILVUS_DB)
          else:
              connections.connect(
                  uri=MILVUS_URI, token=MILVUS_TOKEN, db_name=MILVUS_DB
              )

          full_name = self._full_name(collection_name)
          if not self.client.has_collection(collection_name=full_name):
              log.warning(f"Query attempted on non-existent collection: {full_name}")
              return None

          filter_string = self._build_filter_string(filter)
          all_results = []
          iterator = None
          try:
              log.info(
                  f"Querying collection {full_name} with filter: '{filter_string}', limit: {limit}"
              )
              collection = Collection(full_name)
              collection.load()
              iterator = collection.query_iterator(
                  # The ORM iterator takes the boolean expression as ``expr``;
                  # ``filter`` is not a named parameter there and would not
                  # restrict the results.
                  expr=filter_string,
                  output_fields=["id", "data", "metadata"],
                  # The iterator uses -1 as its "unlimited" sentinel; None is
                  # not accepted.
                  limit=-1 if limit is None else limit,
              )
              while batch := iterator.next():
                  all_results.extend(batch)
              log.info(f"Total results from query: {len(all_results)}")
              return self._result_to_get_result([all_results])
          except Exception as e:
              log.exception(
                  f"Error querying collection {full_name} with filter '{filter_string}' and limit {limit}: {e}"
              )
              return None
          finally:
              # Release the server-side iterator even if pagination failed.
              if iterator is not None:
                  iterator.close()

      def get(self, collection_name: str) -> Optional[GetResult]:
          """Return every item in the collection (can be expensive for large collections)."""
          collection_name = collection_name.replace("-", "_")
          log.warning(
              f"Fetching ALL items from collection '{self.collection_prefix}_{collection_name}'. This might be slow for large collections."
          )
          # An empty filter matches every row; query() paginates internally.
          return self.query(collection_name=collection_name, filter={}, limit=None)

      def insert(self, collection_name: str, items: list[VectorItem]):
          """Insert *items*, creating the collection first if needed.

          Raises:
              ValueError: If the collection must be created but *items* is empty.
          """
          collection_name = collection_name.replace("-", "_")
          self._ensure_collection(collection_name, items)
          full_name = self._full_name(collection_name)
          log.info(f"Inserting {len(items)} items into collection {full_name}.")
          return self.client.insert(
              collection_name=full_name, data=self._to_rows(items)
          )

      def upsert(self, collection_name: str, items: list[VectorItem]):
          """Insert or update *items* by id, creating the collection first if needed.

          Raises:
              ValueError: If the collection must be created but *items* is empty.
          """
          collection_name = collection_name.replace("-", "_")
          self._ensure_collection(collection_name, items, operation=" for upsert")
          full_name = self._full_name(collection_name)
          log.info(f"Upserting {len(items)} items into collection {full_name}.")
          return self.client.upsert(
              collection_name=full_name, data=self._to_rows(items)
          )

      def delete(
          self,
          collection_name: str,
          ids: Optional[list[str]] = None,
          filter: Optional[dict] = None,
      ):
          """Delete items by *ids* or, if no ids are given, by metadata *filter*.

          Returns the client's delete result, or ``None`` when the collection is
          missing or neither ids nor filter were supplied.
          """
          full_name = self._full_name(collection_name)
          if not self.client.has_collection(collection_name=full_name):
              log.warning(f"Delete attempted on non-existent collection: {full_name}")
              return None

          if ids:
              log.info(f"Deleting items by IDs from {full_name}. IDs: {ids}")
              return self.client.delete(collection_name=full_name, ids=ids)

          if filter:
              filter_string = self._build_filter_string(filter)
              log.info(
                  f"Deleting items by filter from {full_name}. Filter: {filter_string}"
              )
              return self.client.delete(collection_name=full_name, filter=filter_string)

          log.warning(
              f"Delete operation on {full_name} called without IDs or filter. No action taken."
          )
          return None

      def reset(self):
          """Drop every collection created by this adapter (best-effort; errors are logged)."""
          log.warning(
              f"Resetting Milvus: Deleting all collections with prefix '{self.collection_prefix}'."
          )
          # Match the "<prefix>_" naming used by _full_name so unrelated
          # collections that merely start with the prefix text are left alone.
          owned_prefix = f"{self.collection_prefix}_"
          deleted_collections = []
          for collection_name_full in self.client.list_collections():
              if not collection_name_full.startswith(owned_prefix):
                  continue
              try:
                  self.client.drop_collection(collection_name=collection_name_full)
                  deleted_collections.append(collection_name_full)
                  log.info(f"Deleted collection: {collection_name_full}")
              except Exception as e:
                  log.error(f"Error deleting collection {collection_name_full}: {e}")
          log.info(f"Milvus reset complete. Deleted collections: {deleted_collections}")