milvus.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365
  1. from pymilvus import MilvusClient as Client
  2. from pymilvus import FieldSchema, DataType
  3. import json
  4. import logging
  5. from typing import Optional
  6. from open_webui.retrieval.vector.main import (
  7. VectorDBBase,
  8. VectorItem,
  9. SearchResult,
  10. GetResult,
  11. )
  12. from open_webui.config import (
  13. MILVUS_URI,
  14. MILVUS_DB,
  15. MILVUS_TOKEN,
  16. MILVUS_INDEX_TYPE,
  17. MILVUS_METRIC_TYPE,
  18. MILVUS_HNSW_M,
  19. MILVUS_HNSW_EFCONSTRUCTION,
  20. MILVUS_IVF_FLAT_NLIST,
  21. )
  22. from open_webui.env import SRC_LOG_LEVELS
  23. log = logging.getLogger(__name__)
  24. log.setLevel(SRC_LOG_LEVELS["RAG"])
  25. class MilvusClient(VectorDBBase):
  26. def __init__(self):
  27. self.collection_prefix = "open_webui"
  28. if MILVUS_TOKEN is None:
  29. self.client = Client(uri=MILVUS_URI, db_name=MILVUS_DB)
  30. else:
  31. self.client = Client(uri=MILVUS_URI, db_name=MILVUS_DB, token=MILVUS_TOKEN)
  32. def _result_to_get_result(self, result) -> GetResult:
  33. ids = []
  34. documents = []
  35. metadatas = []
  36. for match in result:
  37. _ids = []
  38. _documents = []
  39. _metadatas = []
  40. for item in match:
  41. _ids.append(item.get("id"))
  42. _documents.append(item.get("data", {}).get("text"))
  43. _metadatas.append(item.get("metadata"))
  44. ids.append(_ids)
  45. documents.append(_documents)
  46. metadatas.append(_metadatas)
  47. return GetResult(
  48. **{
  49. "ids": ids,
  50. "documents": documents,
  51. "metadatas": metadatas,
  52. }
  53. )
  54. def _result_to_search_result(self, result) -> SearchResult:
  55. ids = []
  56. distances = []
  57. documents = []
  58. metadatas = []
  59. for match in result:
  60. _ids = []
  61. _distances = []
  62. _documents = []
  63. _metadatas = []
  64. for item in match:
  65. _ids.append(item.get("id"))
  66. # normalize milvus score from [-1, 1] to [0, 1] range
  67. # https://milvus.io/docs/de/metric.md
  68. _dist = (item.get("distance") + 1.0) / 2.0
  69. _distances.append(_dist)
  70. _documents.append(item.get("entity", {}).get("data", {}).get("text"))
  71. _metadatas.append(item.get("entity", {}).get("metadata"))
  72. ids.append(_ids)
  73. distances.append(_distances)
  74. documents.append(_documents)
  75. metadatas.append(_metadatas)
  76. return SearchResult(
  77. **{
  78. "ids": ids,
  79. "distances": distances,
  80. "documents": documents,
  81. "metadatas": metadatas,
  82. }
  83. )
  84. def _create_collection(self, collection_name: str, dimension: int):
  85. schema = self.client.create_schema(
  86. auto_id=False,
  87. enable_dynamic_field=True,
  88. )
  89. schema.add_field(
  90. field_name="id",
  91. datatype=DataType.VARCHAR,
  92. is_primary=True,
  93. max_length=65535,
  94. )
  95. schema.add_field(
  96. field_name="vector",
  97. datatype=DataType.FLOAT_VECTOR,
  98. dim=dimension,
  99. description="vector",
  100. )
  101. schema.add_field(field_name="data", datatype=DataType.JSON, description="data")
  102. schema.add_field(
  103. field_name="metadata", datatype=DataType.JSON, description="metadata"
  104. )
  105. index_params = self.client.prepare_index_params()
  106. # Use configurations from config.py
  107. index_type = MILVUS_INDEX_TYPE.upper()
  108. metric_type = MILVUS_METRIC_TYPE.upper()
  109. log.info(f"Using Milvus index type: {index_type}, metric type: {metric_type}")
  110. index_creation_params = {}
  111. if index_type == "HNSW":
  112. index_creation_params = {"M": MILVUS_HNSW_M, "efConstruction": MILVUS_HNSW_EFCONSTRUCTION}
  113. log.info(f"HNSW params: {index_creation_params}")
  114. elif index_type == "IVF_FLAT":
  115. index_creation_params = {"nlist": MILVUS_IVF_FLAT_NLIST}
  116. log.info(f"IVF_FLAT params: {index_creation_params}")
  117. elif index_type in ["FLAT", "AUTOINDEX"]:
  118. log.info(f"Using {index_type} index with no specific build-time params.")
  119. else:
  120. log.warning(
  121. f"Unsupported MILVUS_INDEX_TYPE: '{index_type}'. "
  122. f"Supported types: HNSW, IVF_FLAT, FLAT, AUTOINDEX. "
  123. f"Milvus will use its default for the collection if this type is not directly supported for index creation."
  124. )
  125. # For unsupported types, pass the type directly to Milvus; it might handle it or use a default.
  126. # If Milvus errors out, the user needs to correct the MILVUS_INDEX_TYPE env var.
  127. index_params.add_index(
  128. field_name="vector",
  129. index_type=index_type,
  130. metric_type=metric_type,
  131. params=index_creation_params,
  132. )
  133. self.client.create_collection(
  134. collection_name=f"{self.collection_prefix}_{collection_name}",
  135. schema=schema,
  136. index_params=index_params,
  137. )
  138. log.info(f"Successfully created collection '{self.collection_prefix}_{collection_name}' with index type '{index_type}' and metric '{metric_type}'.")
  139. def has_collection(self, collection_name: str) -> bool:
  140. # Check if the collection exists based on the collection name.
  141. collection_name = collection_name.replace("-", "_")
  142. return self.client.has_collection(
  143. collection_name=f"{self.collection_prefix}_{collection_name}"
  144. )
  145. def delete_collection(self, collection_name: str):
  146. # Delete the collection based on the collection name.
  147. collection_name = collection_name.replace("-", "_")
  148. return self.client.drop_collection(
  149. collection_name=f"{self.collection_prefix}_{collection_name}"
  150. )
  151. def search(
  152. self, collection_name: str, vectors: list[list[float | int]], limit: int
  153. ) -> Optional[SearchResult]:
  154. # Search for the nearest neighbor items based on the vectors and return 'limit' number of results.
  155. collection_name = collection_name.replace("-", "_")
  156. # For some index types like IVF_FLAT, search params like nprobe can be set.
  157. # Example: search_params = {"nprobe": 10} if using IVF_FLAT
  158. # For simplicity, not adding configurable search_params here, but could be extended.
  159. result = self.client.search(
  160. collection_name=f"{self.collection_prefix}_{collection_name}",
  161. data=vectors,
  162. limit=limit,
  163. output_fields=["data", "metadata"],
  164. # search_params=search_params # Potentially add later if needed
  165. )
  166. return self._result_to_search_result(result)
  167. def query(self, collection_name: str, filter: dict, limit: Optional[int] = None):
  168. # Construct the filter string for querying
  169. collection_name = collection_name.replace("-", "_")
  170. if not self.has_collection(collection_name):
  171. log.warning(f"Query attempted on non-existent collection: {self.collection_prefix}_{collection_name}")
  172. return None
  173. filter_string = " && ".join(
  174. [
  175. f'metadata["{key}"] == {json.dumps(value)}'
  176. for key, value in filter.items()
  177. ]
  178. )
  179. max_limit = 16383 # The maximum number of records per request
  180. all_results = []
  181. if limit is None:
  182. # Milvus default limit for query if not specified is 16384, but docs mention iteration.
  183. # Let's set a practical high number if "all" is intended, or handle true pagination.
  184. # For now, if limit is None, we'll fetch in batches up to a very large number.
  185. # This part could be refined based on expected use cases for "get all".
  186. # For this function signature, None implies "as many as possible" up to Milvus limits.
  187. limit = 16384 * 10 # A large number to signify fetching many, will be capped by actual data or max_limit per call.
  188. log.info(f"Limit not specified for query, fetching up to {limit} results in batches.")
  189. # Initialize offset and remaining to handle pagination
  190. offset = 0
  191. remaining = limit
  192. try:
  193. log.info(f"Querying collection {self.collection_prefix}_{collection_name} with filter: '{filter_string}', limit: {limit}")
  194. # Loop until there are no more items to fetch or the desired limit is reached
  195. while remaining > 0:
  196. current_fetch = min(max_limit, remaining if isinstance(remaining, int) else max_limit)
  197. log.debug(f"Querying with offset: {offset}, current_fetch: {current_fetch}")
  198. results = self.client.query(
  199. collection_name=f"{self.collection_prefix}_{collection_name}",
  200. filter=filter_string,
  201. output_fields=["id", "data", "metadata"], # Explicitly list needed fields. Vector not usually needed in query.
  202. limit=current_fetch,
  203. offset=offset,
  204. )
  205. if not results:
  206. log.debug("No more results from query.")
  207. break
  208. all_results.extend(results)
  209. results_count = len(results)
  210. log.debug(f"Fetched {results_count} results in this batch.")
  211. if isinstance(remaining, int):
  212. remaining -= results_count
  213. offset += results_count
  214. # Break the loop if the results returned are less than the requested fetch count (means end of data)
  215. if results_count < current_fetch:
  216. log.debug("Fetched less than requested, assuming end of results for this query.")
  217. break
  218. log.info(f"Total results from query: {len(all_results)}")
  219. return self._result_to_get_result([all_results])
  220. except Exception as e:
  221. log.exception(
  222. f"Error querying collection {self.collection_prefix}_{collection_name} with filter '{filter_string}' and limit {limit}: {e}"
  223. )
  224. return None
  225. def get(self, collection_name: str) -> Optional[GetResult]:
  226. # Get all the items in the collection. This can be very resource-intensive for large collections.
  227. collection_name = collection_name.replace("-", "_")
  228. log.warning(f"Fetching ALL items from collection '{self.collection_prefix}_{collection_name}'. This might be slow for large collections.")
  229. # Using query with a trivial filter to get all items.
  230. # This will use the paginated query logic.
  231. return self.query(collection_name=collection_name, filter={}, limit=None)
  232. def insert(self, collection_name: str, items: list[VectorItem]):
  233. # Insert the items into the collection, if the collection does not exist, it will be created.
  234. collection_name = collection_name.replace("-", "_")
  235. if not self.client.has_collection(
  236. collection_name=f"{self.collection_prefix}_{collection_name}"
  237. ):
  238. log.info(f"Collection {self.collection_prefix}_{collection_name} does not exist. Creating now.")
  239. if not items:
  240. log.error(f"Cannot create collection {self.collection_prefix}_{collection_name} without items to determine dimension.")
  241. raise ValueError("Cannot create Milvus collection without items to determine vector dimension.")
  242. self._create_collection(
  243. collection_name=collection_name, dimension=len(items[0]["vector"])
  244. )
  245. log.info(f"Inserting {len(items)} items into collection {self.collection_prefix}_{collection_name}.")
  246. return self.client.insert(
  247. collection_name=f"{self.collection_prefix}_{collection_name}",
  248. data=[
  249. {
  250. "id": item["id"],
  251. "vector": item["vector"],
  252. "data": {"text": item["text"]},
  253. "metadata": item["metadata"],
  254. }
  255. for item in items
  256. ],
  257. )
  258. def upsert(self, collection_name: str, items: list[VectorItem]):
  259. # Update the items in the collection, if the items are not present, insert them. If the collection does not exist, it will be created.
  260. collection_name = collection_name.replace("-", "_")
  261. if not self.client.has_collection(
  262. collection_name=f"{self.collection_prefix}_{collection_name}"
  263. ):
  264. log.info(f"Collection {self.collection_prefix}_{collection_name} does not exist for upsert. Creating now.")
  265. if not items:
  266. log.error(f"Cannot create collection {self.collection_prefix}_{collection_name} for upsert without items to determine dimension.")
  267. raise ValueError("Cannot create Milvus collection for upsert without items to determine vector dimension.")
  268. self._create_collection(
  269. collection_name=collection_name, dimension=len(items[0]["vector"])
  270. )
  271. log.info(f"Upserting {len(items)} items into collection {self.collection_prefix}_{collection_name}.")
  272. return self.client.upsert(
  273. collection_name=f"{self.collection_prefix}_{collection_name}",
  274. data=[
  275. {
  276. "id": item["id"],
  277. "vector": item["vector"],
  278. "data": {"text": item["text"]},
  279. "metadata": item["metadata"],
  280. }
  281. for item in items
  282. ],
  283. )
  284. def delete(
  285. self,
  286. collection_name: str,
  287. ids: Optional[list[str]] = None,
  288. filter: Optional[dict] = None,
  289. ):
  290. # Delete the items from the collection based on the ids or filter.
  291. collection_name = collection_name.replace("-", "_")
  292. if not self.has_collection(collection_name):
  293. log.warning(f"Delete attempted on non-existent collection: {self.collection_prefix}_{collection_name}")
  294. return None
  295. if ids:
  296. log.info(f"Deleting items by IDs from {self.collection_prefix}_{collection_name}. IDs: {ids}")
  297. return self.client.delete(
  298. collection_name=f"{self.collection_prefix}_{collection_name}",
  299. ids=ids,
  300. )
  301. elif filter:
  302. filter_string = " && ".join(
  303. [
  304. f'metadata["{key}"] == {json.dumps(value)}'
  305. for key, value in filter.items()
  306. ]
  307. )
  308. log.info(f"Deleting items by filter from {self.collection_prefix}_{collection_name}. Filter: {filter_string}")
  309. return self.client.delete(
  310. collection_name=f"{self.collection_prefix}_{collection_name}",
  311. filter=filter_string,
  312. )
  313. else:
  314. log.warning(f"Delete operation on {self.collection_prefix}_{collection_name} called without IDs or filter. No action taken.")
  315. return None
  316. def reset(self):
  317. # Resets the database. This will delete all collections and item entries that match the prefix.
  318. log.warning(f"Resetting Milvus: Deleting all collections with prefix '{self.collection_prefix}'.")
  319. collection_names = self.client.list_collections()
  320. deleted_collections = []
  321. for collection_name_full in collection_names:
  322. if collection_name_full.startswith(self.collection_prefix):
  323. try:
  324. self.client.drop_collection(collection_name=collection_name_full)
  325. deleted_collections.append(collection_name_full)
  326. log.info(f"Deleted collection: {collection_name_full}")
  327. except Exception as e:
  328. log.error(f"Error deleting collection {collection_name_full}: {e}")
  329. log.info(f"Milvus reset complete. Deleted collections: {deleted_collections}")