from pymilvus import MilvusClient as Client
from pymilvus import FieldSchema, DataType

import json
import logging
from typing import Optional

from open_webui.retrieval.vector.utils import stringify_metadata
from open_webui.retrieval.vector.main import (
    VectorDBBase,
    VectorItem,
    SearchResult,
    GetResult,
)
from open_webui.config import (
    MILVUS_URI,
    MILVUS_DB,
    MILVUS_TOKEN,
    MILVUS_INDEX_TYPE,
    MILVUS_METRIC_TYPE,
    MILVUS_HNSW_M,
    MILVUS_HNSW_EFCONSTRUCTION,
    MILVUS_IVF_FLAT_NLIST,
)
from open_webui.env import SRC_LOG_LEVELS

log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])


class MilvusClient(VectorDBBase):
    def __init__(self):
        self.collection_prefix = "open_webui"
        if MILVUS_TOKEN is None:
            self.client = Client(uri=MILVUS_URI, db_name=MILVUS_DB)
        else:
            self.client = Client(uri=MILVUS_URI, db_name=MILVUS_DB, token=MILVUS_TOKEN)

    def _result_to_get_result(self, result) -> GetResult:
        ids = []
        documents = []
        metadatas = []

        for match in result:
            _ids = []
            _documents = []
            _metadatas = []

            for item in match:
                _ids.append(item.get("id"))
                _documents.append(item.get("data", {}).get("text"))
                _metadatas.append(item.get("metadata"))

            ids.append(_ids)
            documents.append(_documents)
            metadatas.append(_metadatas)

        return GetResult(
            **{
                "ids": ids,
                "documents": documents,
                "metadatas": metadatas,
            }
        )

    def _result_to_search_result(self, result) -> SearchResult:
        ids = []
        distances = []
        documents = []
        metadatas = []

        for match in result:
            _ids = []
            _distances = []
            _documents = []
            _metadatas = []

            for item in match:
                _ids.append(item.get("id"))
                # normalize milvus score from [-1, 1] to [0, 1] range
                # https://milvus.io/docs/de/metric.md
                _dist = (item.get("distance") + 1.0) / 2.0
                _distances.append(_dist)
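                # For example, with a similarity-style metric such as COSINE
                # (raw scores in [-1, 1]): 1.0 -> 1.0, 0.0 -> 0.5, -1.0 -> 0.0.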
                _documents.append(item.get("entity", {}).get("data", {}).get("text"))
                _metadatas.append(item.get("entity", {}).get("metadata"))

            ids.append(_ids)
            distances.append(_distances)
            documents.append(_documents)
            metadatas.append(_metadatas)

        return SearchResult(
            **{
                "ids": ids,
                "distances": distances,
                "documents": documents,
                "metadatas": metadatas,
            }
        )

    def _create_collection(self, collection_name: str, dimension: int):
        schema = self.client.create_schema(
            auto_id=False,
            enable_dynamic_field=True,
        )
        schema.add_field(
            field_name="id",
            datatype=DataType.VARCHAR,
            is_primary=True,
            max_length=65535,
        )
        schema.add_field(
            field_name="vector",
            datatype=DataType.FLOAT_VECTOR,
            dim=dimension,
            description="vector",
        )
        schema.add_field(field_name="data", datatype=DataType.JSON, description="data")
        schema.add_field(
            field_name="metadata", datatype=DataType.JSON, description="metadata"
        )

        index_params = self.client.prepare_index_params()

        # Use configurations from config.py
        index_type = MILVUS_INDEX_TYPE.upper()
        metric_type = MILVUS_METRIC_TYPE.upper()
        log.info(f"Using Milvus index type: {index_type}, metric type: {metric_type}")

        index_creation_params = {}
        if index_type == "HNSW":
            index_creation_params = {
                "M": MILVUS_HNSW_M,
                "efConstruction": MILVUS_HNSW_EFCONSTRUCTION,
            }
            log.info(f"HNSW params: {index_creation_params}")
        elif index_type == "IVF_FLAT":
            index_creation_params = {"nlist": MILVUS_IVF_FLAT_NLIST}
            log.info(f"IVF_FLAT params: {index_creation_params}")
        elif index_type in ["FLAT", "AUTOINDEX"]:
            log.info(f"Using {index_type} index with no specific build-time params.")
        else:
            log.warning(
                f"Unsupported MILVUS_INDEX_TYPE: '{index_type}'. "
                f"Supported types: HNSW, IVF_FLAT, FLAT, AUTOINDEX. "
                f"Milvus will use its default for the collection if this type is not directly supported for index creation."
            )
            # For unsupported types, pass the type directly to Milvus; it might handle it or use a default.
            # If Milvus errors out, the user needs to correct the MILVUS_INDEX_TYPE env var.

        index_params.add_index(
            field_name="vector",
            index_type=index_type,
            metric_type=metric_type,
            params=index_creation_params,
        )
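
        # Illustrative example (hypothetical values, not project defaults): with
        # MILVUS_INDEX_TYPE="HNSW", MILVUS_METRIC_TYPE="COSINE", MILVUS_HNSW_M=16,
        # and MILVUS_HNSW_EFCONSTRUCTION=200, the call above resolves to
        # add_index(field_name="vector", index_type="HNSW", metric_type="COSINE",
        #           params={"M": 16, "efConstruction": 200}).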

        self.client.create_collection(
            collection_name=f"{self.collection_prefix}_{collection_name}",
            schema=schema,
            index_params=index_params,
        )
        log.info(
            f"Successfully created collection '{self.collection_prefix}_{collection_name}' with index type '{index_type}' and metric '{metric_type}'."
        )

    def has_collection(self, collection_name: str) -> bool:
        # Check if the collection exists based on the collection name.
        collection_name = collection_name.replace("-", "_")
        return self.client.has_collection(
            collection_name=f"{self.collection_prefix}_{collection_name}"
        )

    def delete_collection(self, collection_name: str):
        # Delete the collection based on the collection name.
        collection_name = collection_name.replace("-", "_")
        return self.client.drop_collection(
            collection_name=f"{self.collection_prefix}_{collection_name}"
        )

    def search(
        self, collection_name: str, vectors: list[list[float | int]], limit: int
    ) -> Optional[SearchResult]:
        # Search for the nearest neighbor items based on the vectors and return 'limit' number of results.
        collection_name = collection_name.replace("-", "_")

        # For some index types such as IVF_FLAT, search params like nprobe can be set,
        # e.g. search_params = {"nprobe": 10}. Configurable search_params are not
        # exposed here, but this could be extended later.
        result = self.client.search(
            collection_name=f"{self.collection_prefix}_{collection_name}",
            data=vectors,
            limit=limit,
            output_fields=["data", "metadata"],
            # search_params=search_params  # potentially add later if needed
        )
        return self._result_to_search_result(result)

    def query(self, collection_name: str, filter: dict, limit: Optional[int] = None):
        # Construct the filter string for querying
        collection_name = collection_name.replace("-", "_")
        if not self.has_collection(collection_name):
            log.warning(
                f"Query attempted on non-existent collection: {self.collection_prefix}_{collection_name}"
            )
            return None

        filter_string = " && ".join(
            [
                f'metadata["{key}"] == {json.dumps(value)}'
                for key, value in filter.items()
            ]
        )
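        # Example (hypothetical filter): {"file_id": "abc", "page": 1} becomes the
        # Milvus expression: metadata["file_id"] == "abc" && metadata["page"] == 1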
        max_limit = 16383  # maximum number of records per request
        all_results = []

        if limit is None:
            # limit=None means "fetch as many as possible": use a large cap and
            # page through the results in batches of at most max_limit per call.
            limit = 16384 * 10
            log.info(
                f"Limit not specified for query, fetching up to {limit} results in batches."
            )

        # Initialize offset and remaining to handle pagination
        offset = 0
        remaining = limit

        try:
            log.info(
                f"Querying collection {self.collection_prefix}_{collection_name} with filter: '{filter_string}', limit: {limit}"
            )
            # Loop until there are no more items to fetch or the desired limit is reached
            while remaining > 0:
                current_fetch = min(
                    max_limit, remaining if isinstance(remaining, int) else max_limit
                )
                log.debug(
                    f"Querying with offset: {offset}, current_fetch: {current_fetch}"
                )
                results = self.client.query(
                    collection_name=f"{self.collection_prefix}_{collection_name}",
                    filter=filter_string,
                    output_fields=[
                        "id",
                        "data",
                        "metadata",
                    ],  # explicitly list the needed fields; the vector is not needed here
                    limit=current_fetch,
                    offset=offset,
                )
                if not results:
                    log.debug("No more results from query.")
                    break

                all_results.extend(results)
                results_count = len(results)
                log.debug(f"Fetched {results_count} results in this batch.")

                if isinstance(remaining, int):
                    remaining -= results_count
                offset += results_count

                # Stop if fewer results were returned than requested (end of data).
                if results_count < current_fetch:
                    log.debug(
                        "Fetched less than requested, assuming end of results for this query."
                    )
                    break

            log.info(f"Total results from query: {len(all_results)}")
            return self._result_to_get_result([all_results])
        except Exception as e:
            log.exception(
                f"Error querying collection {self.collection_prefix}_{collection_name} with filter '{filter_string}' and limit {limit}: {e}"
            )
            return None

    def get(self, collection_name: str) -> Optional[GetResult]:
        # Get all the items in the collection. This can be very resource-intensive for large collections.
        collection_name = collection_name.replace("-", "_")
        log.warning(
            f"Fetching ALL items from collection '{self.collection_prefix}_{collection_name}'. This might be slow for large collections."
        )
        # Using query with a trivial filter to get all items.
        # This will use the paginated query logic.
        return self.query(collection_name=collection_name, filter={}, limit=None)

    def insert(self, collection_name: str, items: list[VectorItem]):
        # Insert the items into the collection; if the collection does not exist, it is created.
        collection_name = collection_name.replace("-", "_")
        if not self.client.has_collection(
            collection_name=f"{self.collection_prefix}_{collection_name}"
        ):
            log.info(
                f"Collection {self.collection_prefix}_{collection_name} does not exist. Creating now."
            )
            if not items:
                log.error(
                    f"Cannot create collection {self.collection_prefix}_{collection_name} without items to determine dimension."
                )
                raise ValueError(
                    "Cannot create Milvus collection without items to determine vector dimension."
                )
            self._create_collection(
                collection_name=collection_name, dimension=len(items[0]["vector"])
            )

        log.info(
            f"Inserting {len(items)} items into collection {self.collection_prefix}_{collection_name}."
        )
        return self.client.insert(
            collection_name=f"{self.collection_prefix}_{collection_name}",
            data=[
                {
                    "id": item["id"],
                    "vector": item["vector"],
                    "data": {"text": item["text"]},
                    "metadata": stringify_metadata(item["metadata"]),
                }
                for item in items
            ],
        )

    def upsert(self, collection_name: str, items: list[VectorItem]):
        # Update the items in the collection; insert any that are not present.
        # If the collection does not exist, it is created.
        collection_name = collection_name.replace("-", "_")
        if not self.client.has_collection(
            collection_name=f"{self.collection_prefix}_{collection_name}"
        ):
            log.info(
                f"Collection {self.collection_prefix}_{collection_name} does not exist for upsert. Creating now."
            )
            if not items:
                log.error(
                    f"Cannot create collection {self.collection_prefix}_{collection_name} for upsert without items to determine dimension."
                )
                raise ValueError(
                    "Cannot create Milvus collection for upsert without items to determine vector dimension."
                )
            self._create_collection(
                collection_name=collection_name, dimension=len(items[0]["vector"])
            )

        log.info(
            f"Upserting {len(items)} items into collection {self.collection_prefix}_{collection_name}."
        )
        return self.client.upsert(
            collection_name=f"{self.collection_prefix}_{collection_name}",
            data=[
                {
                    "id": item["id"],
                    "vector": item["vector"],
                    "data": {"text": item["text"]},
                    "metadata": stringify_metadata(item["metadata"]),
                }
                for item in items
            ],
        )

    def delete(
        self,
        collection_name: str,
        ids: Optional[list[str]] = None,
        filter: Optional[dict] = None,
    ):
        # Delete the items from the collection based on the ids or filter.
        collection_name = collection_name.replace("-", "_")
        if not self.has_collection(collection_name):
            log.warning(
                f"Delete attempted on non-existent collection: {self.collection_prefix}_{collection_name}"
            )
            return None

        if ids:
            log.info(
                f"Deleting items by IDs from {self.collection_prefix}_{collection_name}. IDs: {ids}"
            )
            return self.client.delete(
                collection_name=f"{self.collection_prefix}_{collection_name}",
                ids=ids,
            )
        elif filter:
            filter_string = " && ".join(
                [
                    f'metadata["{key}"] == {json.dumps(value)}'
                    for key, value in filter.items()
                ]
            )
            log.info(
                f"Deleting items by filter from {self.collection_prefix}_{collection_name}. Filter: {filter_string}"
            )
            return self.client.delete(
                collection_name=f"{self.collection_prefix}_{collection_name}",
                filter=filter_string,
            )
        else:
            log.warning(
                f"Delete operation on {self.collection_prefix}_{collection_name} called without IDs or filter. No action taken."
            )
            return None

    def reset(self):
        # Resets the database. This will delete all collections and item entries that match the prefix.
        log.warning(
            f"Resetting Milvus: Deleting all collections with prefix '{self.collection_prefix}'."
        )
        collection_names = self.client.list_collections()
        deleted_collections = []
        for collection_name_full in collection_names:
            if collection_name_full.startswith(self.collection_prefix):
                try:
                    self.client.drop_collection(collection_name=collection_name_full)
                    deleted_collections.append(collection_name_full)
                    log.info(f"Deleted collection: {collection_name_full}")
                except Exception as e:
                    log.error(f"Error deleting collection {collection_name_full}: {e}")
        log.info(f"Milvus reset complete. Deleted collections: {deleted_collections}")
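

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the Open WebUI runtime).
# It assumes a reachable Milvus instance configured via the MILVUS_* settings
# imported above; the collection name, texts, metadata, and 4-dimensional
# vectors are made-up placeholders, and items are passed as plain dicts with
# the keys this module reads ("id", "vector", "text", "metadata").
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    client = MilvusClient()

    # Insert two toy items; the collection is created on first insert, with the
    # vector dimension inferred from the first item's vector.
    client.insert(
        collection_name="example-collection",
        items=[
            {
                "id": "doc-1",
                "vector": [0.1, 0.2, 0.3, 0.4],
                "text": "hello world",
                "metadata": {"file_id": "abc"},
            },
            {
                "id": "doc-2",
                "vector": [0.4, 0.3, 0.2, 0.1],
                "text": "goodbye world",
                "metadata": {"file_id": "def"},
            },
        ],
    )

    # Nearest-neighbour search with a single query vector; distances are
    # normalized to [0, 1] by _result_to_search_result.
    print(
        client.search(
            collection_name="example-collection",
            vectors=[[0.1, 0.2, 0.3, 0.4]],
            limit=2,
        )
    )

    # Metadata-filtered lookup, then clean up the demo collection.
    print(client.query(collection_name="example-collection", filter={"file_id": "abc"}))
    client.delete_collection(collection_name="example-collection")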