milvus.py
from pymilvus import MilvusClient as Client
from pymilvus import FieldSchema, DataType

import json
import logging
from typing import Optional

from open_webui.retrieval.vector.main import (
    VectorDBBase,
    VectorItem,
    SearchResult,
    GetResult,
)
from open_webui.config import (
    MILVUS_URI,
    MILVUS_DB,
    MILVUS_TOKEN,
    MILVUS_INDEX_TYPE,
    MILVUS_METRIC_TYPE,
    MILVUS_HNSW_M,
    MILVUS_HNSW_EFCONSTRUCTION,
    MILVUS_IVF_FLAT_NLIST,
)
from open_webui.env import SRC_LOG_LEVELS

log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])


class MilvusClient(VectorDBBase):
    def __init__(self):
        self.collection_prefix = "open_webui"
        if MILVUS_TOKEN is None:
            self.client = Client(uri=MILVUS_URI, db_name=MILVUS_DB)
        else:
            self.client = Client(
                uri=MILVUS_URI, db_name=MILVUS_DB, token=MILVUS_TOKEN
            )

    def _result_to_get_result(self, result) -> GetResult:
        ids = []
        documents = []
        metadatas = []

        for match in result:
            _ids = []
            _documents = []
            _metadatas = []

            for item in match:
                _ids.append(item.get("id"))
                _documents.append(item.get("data", {}).get("text"))
                _metadatas.append(item.get("metadata"))

            ids.append(_ids)
            documents.append(_documents)
            metadatas.append(_metadatas)

        return GetResult(
            **{
                "ids": ids,
                "documents": documents,
                "metadatas": metadatas,
            }
        )

    def _result_to_search_result(self, result) -> SearchResult:
        ids = []
        distances = []
        documents = []
        metadatas = []

        for match in result:
            _ids = []
            _distances = []
            _documents = []
            _metadatas = []

            for item in match:
                _ids.append(item.get("id"))
                # normalize milvus score from [-1, 1] to [0, 1] range
                # https://milvus.io/docs/de/metric.md
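                # (This rescaling assumes the raw score falls in [-1, 1], e.g. COSINE
                # similarity; it is not meaningful for unbounded metrics such as L2.)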
                _dist = (item.get("distance") + 1.0) / 2.0
                _distances.append(_dist)
                _documents.append(item.get("entity", {}).get("data", {}).get("text"))
                _metadatas.append(item.get("entity", {}).get("metadata"))

            ids.append(_ids)
            distances.append(_distances)
            documents.append(_documents)
            metadatas.append(_metadatas)

        return SearchResult(
            **{
                "ids": ids,
                "distances": distances,
                "documents": documents,
                "metadatas": metadatas,
            }
        )

    def _create_collection(self, collection_name: str, dimension: int):
        schema = self.client.create_schema(
            auto_id=False,
            enable_dynamic_field=True,
        )
        schema.add_field(
            field_name="id",
            datatype=DataType.VARCHAR,
            is_primary=True,
            max_length=65535,
        )
        schema.add_field(
            field_name="vector",
            datatype=DataType.FLOAT_VECTOR,
            dim=dimension,
            description="vector",
        )
        schema.add_field(field_name="data", datatype=DataType.JSON, description="data")
        schema.add_field(
            field_name="metadata", datatype=DataType.JSON, description="metadata"
        )

        index_params = self.client.prepare_index_params()

        # Use configurations from config.py
        index_type = MILVUS_INDEX_TYPE.upper()
        metric_type = MILVUS_METRIC_TYPE.upper()
        log.info(f"Using Milvus index type: {index_type}, metric type: {metric_type}")

        index_creation_params = {}
        if index_type == "HNSW":
            index_creation_params = {
                "M": MILVUS_HNSW_M,
                "efConstruction": MILVUS_HNSW_EFCONSTRUCTION,
            }
            log.info(f"HNSW params: {index_creation_params}")
        elif index_type == "IVF_FLAT":
            index_creation_params = {"nlist": MILVUS_IVF_FLAT_NLIST}
            log.info(f"IVF_FLAT params: {index_creation_params}")
        elif index_type in ["FLAT", "AUTOINDEX"]:
            log.info(f"Using {index_type} index with no specific build-time params.")
        else:
            log.warning(
                f"Unsupported MILVUS_INDEX_TYPE: '{index_type}'. "
                f"Supported types: HNSW, IVF_FLAT, FLAT, AUTOINDEX. "
                f"Milvus will use its default for the collection if this type is not directly supported for index creation."
            )
            # For unsupported types, pass the type directly to Milvus; it might handle it or use a default.
            # If Milvus errors out, the user needs to correct the MILVUS_INDEX_TYPE env var.
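
        # Illustrative sketch (hypothetical values, not project defaults): with
        # MILVUS_INDEX_TYPE="HNSW", MILVUS_METRIC_TYPE="COSINE", MILVUS_HNSW_M=16 and
        # MILVUS_HNSW_EFCONSTRUCTION=200, the add_index call below is equivalent to:
        #
        #   index_params.add_index(
        #       field_name="vector",
        #       index_type="HNSW",
        #       metric_type="COSINE",
        #       params={"M": 16, "efConstruction": 200},
        #   )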

        index_params.add_index(
            field_name="vector",
            index_type=index_type,
            metric_type=metric_type,
            params=index_creation_params,
        )

        self.client.create_collection(
            collection_name=f"{self.collection_prefix}_{collection_name}",
            schema=schema,
            index_params=index_params,
        )
        log.info(
            f"Successfully created collection '{self.collection_prefix}_{collection_name}' with index type '{index_type}' and metric '{metric_type}'."
        )

    def has_collection(self, collection_name: str) -> bool:
        # Check if the collection exists based on the collection name.
        collection_name = collection_name.replace("-", "_")
        return self.client.has_collection(
            collection_name=f"{self.collection_prefix}_{collection_name}"
        )

    def delete_collection(self, collection_name: str):
        # Delete the collection based on the collection name.
        collection_name = collection_name.replace("-", "_")
        return self.client.drop_collection(
            collection_name=f"{self.collection_prefix}_{collection_name}"
        )

    def search(
        self, collection_name: str, vectors: list[list[float | int]], limit: int
    ) -> Optional[SearchResult]:
        # Search for the nearest neighbor items based on the vectors and return 'limit' number of results.
        collection_name = collection_name.replace("-", "_")

        # For some index types like IVF_FLAT, search params like nprobe can be set.
        # Example: search_params = {"nprobe": 10} if using IVF_FLAT
        # For simplicity, not adding configurable search_params here, but could be extended.
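        # Hypothetical extension (not wired to any config here): pymilvus allows
        # per-request tuning via the search_params argument, e.g. for an IVF_FLAT index:
        #
        #   result = self.client.search(
        #       collection_name=f"{self.collection_prefix}_{collection_name}",
        #       data=vectors,
        #       limit=limit,
        #       output_fields=["data", "metadata"],
        #       search_params={"params": {"nprobe": 10}},
        #   )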

        result = self.client.search(
            collection_name=f"{self.collection_prefix}_{collection_name}",
            data=vectors,
            limit=limit,
            output_fields=["data", "metadata"],
            # search_params=search_params  # Potentially add later if needed
        )

        return self._result_to_search_result(result)

    def query(self, collection_name: str, filter: dict, limit: Optional[int] = None):
        # Construct the filter string for querying
        collection_name = collection_name.replace("-", "_")
        if not self.has_collection(collection_name):
            log.warning(
                f"Query attempted on non-existent collection: {self.collection_prefix}_{collection_name}"
            )
            return None

        filter_string = " && ".join(
            [
                f'metadata["{key}"] == {json.dumps(value)}'
                for key, value in filter.items()
            ]
        )
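        # For example (hypothetical values), filter={"file_id": "abc", "page": 1}
        # produces: metadata["file_id"] == "abc" && metadata["page"] == 1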

        max_limit = 16383  # The maximum number of records per request

        all_results = []

        if limit is None:
            # Milvus caps each query request, so when no limit is given we treat it as
            # "fetch everything": set a very large overall cap and page through the
            # results in batches until Milvus returns fewer rows than requested.
            limit = 16384 * 10
            log.info(
                f"Limit not specified for query, fetching up to {limit} results in batches."
            )

        # Initialize offset and remaining to handle pagination
        offset = 0
        remaining = limit

        try:
            log.info(
                f"Querying collection {self.collection_prefix}_{collection_name} with filter: '{filter_string}', limit: {limit}"
            )
            # Loop until there are no more items to fetch or the desired limit is reached
            while remaining > 0:
                current_fetch = min(
                    max_limit, remaining if isinstance(remaining, int) else max_limit
                )
                log.debug(
                    f"Querying with offset: {offset}, current_fetch: {current_fetch}"
                )
                results = self.client.query(
                    collection_name=f"{self.collection_prefix}_{collection_name}",
                    filter=filter_string,
                    output_fields=[
                        "id",
                        "data",
                        "metadata",
                    ],  # Explicitly list the needed fields; the vector is usually not needed in a query.
                    limit=current_fetch,
                    offset=offset,
                )

                if not results:
                    log.debug("No more results from query.")
                    break

                all_results.extend(results)
                results_count = len(results)
                log.debug(f"Fetched {results_count} results in this batch.")

                if isinstance(remaining, int):
                    remaining -= results_count
                offset += results_count

                # Break the loop if fewer results were returned than requested (end of data).
                if results_count < current_fetch:
                    log.debug(
                        "Fetched less than requested, assuming end of results for this query."
                    )
                    break

            log.info(f"Total results from query: {len(all_results)}")
            return self._result_to_get_result([all_results])
        except Exception as e:
            log.exception(
                f"Error querying collection {self.collection_prefix}_{collection_name} with filter '{filter_string}' and limit {limit}: {e}"
            )
            return None

    def get(self, collection_name: str) -> Optional[GetResult]:
        # Get all the items in the collection. This can be very resource-intensive for large collections.
        collection_name = collection_name.replace("-", "_")
        log.warning(
            f"Fetching ALL items from collection '{self.collection_prefix}_{collection_name}'. This might be slow for large collections."
        )
        # Using query with a trivial filter to get all items.
        # This will use the paginated query logic.
        return self.query(collection_name=collection_name, filter={}, limit=None)

    def insert(self, collection_name: str, items: list[VectorItem]):
        # Insert the items into the collection; if the collection does not exist, it will be created.
        collection_name = collection_name.replace("-", "_")
        if not self.client.has_collection(
            collection_name=f"{self.collection_prefix}_{collection_name}"
        ):
            log.info(
                f"Collection {self.collection_prefix}_{collection_name} does not exist. Creating now."
            )
            if not items:
                log.error(
                    f"Cannot create collection {self.collection_prefix}_{collection_name} without items to determine dimension."
                )
                raise ValueError(
                    "Cannot create Milvus collection without items to determine vector dimension."
                )
            self._create_collection(
                collection_name=collection_name, dimension=len(items[0]["vector"])
            )

        log.info(
            f"Inserting {len(items)} items into collection {self.collection_prefix}_{collection_name}."
        )
        return self.client.insert(
            collection_name=f"{self.collection_prefix}_{collection_name}",
            data=[
                {
                    "id": item["id"],
                    "vector": item["vector"],
                    "data": {"text": item["text"]},
                    "metadata": item["metadata"],
                }
                for item in items
            ],
        )

    def upsert(self, collection_name: str, items: list[VectorItem]):
        # Update the items in the collection; if an item is not present, insert it.
        # If the collection does not exist, it will be created.
        collection_name = collection_name.replace("-", "_")
        if not self.client.has_collection(
            collection_name=f"{self.collection_prefix}_{collection_name}"
        ):
            log.info(
                f"Collection {self.collection_prefix}_{collection_name} does not exist for upsert. Creating now."
            )
            if not items:
                log.error(
                    f"Cannot create collection {self.collection_prefix}_{collection_name} for upsert without items to determine dimension."
                )
                raise ValueError(
                    "Cannot create Milvus collection for upsert without items to determine vector dimension."
                )
            self._create_collection(
                collection_name=collection_name, dimension=len(items[0]["vector"])
            )

        log.info(
            f"Upserting {len(items)} items into collection {self.collection_prefix}_{collection_name}."
        )
        return self.client.upsert(
            collection_name=f"{self.collection_prefix}_{collection_name}",
            data=[
                {
                    "id": item["id"],
                    "vector": item["vector"],
                    "data": {"text": item["text"]},
                    "metadata": item["metadata"],
                }
                for item in items
            ],
        )

    def delete(
        self,
        collection_name: str,
        ids: Optional[list[str]] = None,
        filter: Optional[dict] = None,
    ):
        # Delete the items from the collection based on the ids or filter.
        collection_name = collection_name.replace("-", "_")
        if not self.has_collection(collection_name):
            log.warning(
                f"Delete attempted on non-existent collection: {self.collection_prefix}_{collection_name}"
            )
            return None

        if ids:
            log.info(
                f"Deleting items by IDs from {self.collection_prefix}_{collection_name}. IDs: {ids}"
            )
            return self.client.delete(
                collection_name=f"{self.collection_prefix}_{collection_name}",
                ids=ids,
            )
        elif filter:
            filter_string = " && ".join(
                [
                    f'metadata["{key}"] == {json.dumps(value)}'
                    for key, value in filter.items()
                ]
            )
            log.info(
                f"Deleting items by filter from {self.collection_prefix}_{collection_name}. Filter: {filter_string}"
            )
            return self.client.delete(
                collection_name=f"{self.collection_prefix}_{collection_name}",
                filter=filter_string,
            )
        else:
            log.warning(
                f"Delete operation on {self.collection_prefix}_{collection_name} called without IDs or filter. No action taken."
            )
            return None

    def reset(self):
        # Resets the database. This will delete all collections and item entries that match the prefix.
        log.warning(
            f"Resetting Milvus: Deleting all collections with prefix '{self.collection_prefix}'."
        )
        collection_names = self.client.list_collections()
        deleted_collections = []
        for collection_name_full in collection_names:
            if collection_name_full.startswith(self.collection_prefix):
                try:
                    self.client.drop_collection(collection_name=collection_name_full)
                    deleted_collections.append(collection_name_full)
                    log.info(f"Deleted collection: {collection_name_full}")
                except Exception as e:
                    log.error(f"Error deleting collection {collection_name_full}: {e}")
        log.info(f"Milvus reset complete. Deleted collections: {deleted_collections}")
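

# --- Minimal usage sketch (illustrative only, not part of Open WebUI's runtime) ---
# Assumes a reachable Milvus instance at MILVUS_URI; the collection name, ids, text,
# metadata, and the 8-dimensional vectors below are made-up placeholder values.
if __name__ == "__main__":
    client = MilvusClient()

    example_items = [
        {
            "id": "doc-1-chunk-0",  # hypothetical id
            "vector": [0.1] * 8,  # placeholder embedding
            "text": "Hello Milvus",
            "metadata": {"file_id": "doc-1", "page": 1},
        }
    ]

    # First insert creates "open_webui_example_collection" (dimension taken from the items).
    client.insert(collection_name="example_collection", items=example_items)

    # Nearest-neighbour search for one query vector, returning up to 3 hits.
    print(
        client.search(
            collection_name="example_collection", vectors=[[0.1] * 8], limit=3
        )
    )

    # Metadata-filtered query, then a targeted delete by id.
    print(
        client.query(collection_name="example_collection", filter={"file_id": "doc-1"})
    )
    client.delete(collection_name="example_collection", ids=["doc-1-chunk-0"])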