import os
import json
import logging
from typing import Optional

from pymilvus import MilvusClient as Client
from pymilvus import FieldSchema, DataType

from open_webui.retrieval.vector.main import (
    VectorDBBase,
    VectorItem,
    SearchResult,
    GetResult,
)
from open_webui.config import (
    MILVUS_URI,
    MILVUS_DB,
    MILVUS_TOKEN,
)
from open_webui.env import SRC_LOG_LEVELS

log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])


class MilvusClient(VectorDBBase):
    def __init__(self):
        self.collection_prefix = "open_webui"
        if MILVUS_TOKEN is None:
            self.client = Client(uri=MILVUS_URI, db_name=MILVUS_DB)
        else:
            self.client = Client(
                uri=MILVUS_URI, db_name=MILVUS_DB, token=MILVUS_TOKEN
            )

    def _result_to_get_result(self, result) -> GetResult:
        ids = []
        documents = []
        metadatas = []

        for match in result:
            _ids = []
            _documents = []
            _metadatas = []
            for item in match:
                _ids.append(item.get("id"))
                _documents.append(item.get("data", {}).get("text"))
                _metadatas.append(item.get("metadata"))

            ids.append(_ids)
            documents.append(_documents)
            metadatas.append(_metadatas)

        return GetResult(
            **{
                "ids": ids,
                "documents": documents,
                "metadatas": metadatas,
            }
        )

    def _result_to_search_result(self, result) -> SearchResult:
        ids = []
        distances = []
        documents = []
        metadatas = []

        for match in result:
            _ids = []
            _distances = []
            _documents = []
            _metadatas = []
            for item in match:
                _ids.append(item.get("id"))
                # Normalize the Milvus score from the [-1, 1] range to [0, 1].
                # https://milvus.io/docs/de/metric.md
                _dist = (item.get("distance") + 1.0) / 2.0
                _distances.append(_dist)
                _documents.append(item.get("entity", {}).get("data", {}).get("text"))
                _metadatas.append(item.get("entity", {}).get("metadata"))

            ids.append(_ids)
            distances.append(_distances)
            documents.append(_documents)
            metadatas.append(_metadatas)

        return SearchResult(
            **{
                "ids": ids,
                "distances": distances,
                "documents": documents,
                "metadatas": metadatas,
            }
        )

    def _create_collection(self, collection_name: str, dimension: int):
        schema = self.client.create_schema(
            auto_id=False,
            enable_dynamic_field=True,
        )
        schema.add_field(
            field_name="id",
            datatype=DataType.VARCHAR,
            is_primary=True,
            max_length=65535,
        )
        schema.add_field(
            field_name="vector",
            datatype=DataType.FLOAT_VECTOR,
            dim=dimension,
            description="vector",
        )
        schema.add_field(field_name="data", datatype=DataType.JSON, description="data")
        schema.add_field(
            field_name="metadata", datatype=DataType.JSON, description="metadata"
        )

        index_params = self.client.prepare_index_params()

        # Get the index type from an environment variable.
        # Milvus standalone (local mode) supports FLAT, IVF_FLAT, and AUTOINDEX.
        # HNSW is often preferred for performance but may require a clustered
        # Milvus setup. Default to AUTOINDEX for broader compatibility,
        # especially with Milvus standalone.
        default_index_type = "AUTOINDEX"
        milvus_index_type_env = os.getenv("MILVUS_INDEX_TYPE")
        if milvus_index_type_env:
            milvus_index_type = milvus_index_type_env.upper()
            log.info(
                f"Milvus index type from MILVUS_INDEX_TYPE env var: {milvus_index_type}"
            )
        else:
            milvus_index_type = default_index_type
            log.info(
                f"MILVUS_INDEX_TYPE env var not set, defaulting to: {milvus_index_type}"
            )

        index_creation_params = {}
        metric_type = os.getenv("MILVUS_METRIC_TYPE", "COSINE").upper()  # Default to COSINE

        if milvus_index_type == "HNSW":
            # Build-time parameters for HNSW.
            m_env = os.getenv("MILVUS_HNSW_M", "16")
            ef_construction_env = os.getenv("MILVUS_HNSW_EFCONSTRUCTION", "100")
            try:
                m_val = int(m_env)
                ef_val = int(ef_construction_env)
            except ValueError:
                log.warning(
                    f"Invalid HNSW params M='{m_env}' or efConstruction='{ef_construction_env}'. "
                    f"Defaulting to M=16, efConstruction=100."
                )
                m_val = 16
                ef_val = 100
            index_creation_params = {"M": m_val, "efConstruction": ef_val}
            log.info(
                f"Using HNSW index with metric {metric_type}, params: {index_creation_params}"
            )
        elif milvus_index_type == "IVF_FLAT":
            # Build-time parameters for IVF_FLAT.
            nlist_env = os.getenv("MILVUS_IVF_FLAT_NLIST", "128")
            try:
                nlist = int(nlist_env)
            except ValueError:
                log.warning(
                    f"Invalid MILVUS_IVF_FLAT_NLIST value '{nlist_env}'. Defaulting to 128."
                )
                nlist = 128
            index_creation_params = {"nlist": nlist}
            log.info(
                f"Using IVF_FLAT index with metric {metric_type}, params: {index_creation_params}"
            )
        elif milvus_index_type == "FLAT":
            # FLAT needs no specific build-time parameters.
            log.info(
                f"Using FLAT index with metric {metric_type} (no specific build-time params)."
            )
        elif milvus_index_type == "AUTOINDEX":
            # AUTOINDEX needs no specific build-time parameters.
            log.info(
                f"Using AUTOINDEX with metric {metric_type} (params managed by Milvus)."
            )
        else:
            log.warning(
                f"Unsupported or unrecognized MILVUS_INDEX_TYPE: '{milvus_index_type}'. "
                f"Falling back to '{default_index_type}'. "
                f"Supported types: HNSW, IVF_FLAT, FLAT, AUTOINDEX."
            )
            milvus_index_type = default_index_type  # Fall back to a safe default.
            # index_creation_params stays {}, which is fine for AUTOINDEX/FLAT.
            log.info(f"Fell back to {default_index_type} index with metric {metric_type}.")

        index_params.add_index(
            field_name="vector",
            index_type=milvus_index_type,
            metric_type=metric_type,
            params=index_creation_params,
        )

        self.client.create_collection(
            collection_name=f"{self.collection_prefix}_{collection_name}",
            schema=schema,
            index_params=index_params,
        )
        log.info(
            f"Successfully created collection '{self.collection_prefix}_{collection_name}' "
            f"with index type '{milvus_index_type}' and metric '{metric_type}'."
        )
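
    # Example environment configuration for the index selection above (the values
    # are illustrative; the variable names are exactly the ones read by
    # _create_collection):
    #
    #   MILVUS_INDEX_TYPE=HNSW
    #   MILVUS_METRIC_TYPE=COSINE
    #   MILVUS_HNSW_M=16
    #   MILVUS_HNSW_EFCONSTRUCTION=100
    #
    #   # or, for an IVF_FLAT index:
    #   MILVUS_INDEX_TYPE=IVF_FLAT
    #   MILVUS_IVF_FLAT_NLIST=128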

    def has_collection(self, collection_name: str) -> bool:
        # Check if the collection exists based on the collection name.
        collection_name = collection_name.replace("-", "_")
        return self.client.has_collection(
            collection_name=f"{self.collection_prefix}_{collection_name}"
        )

    def delete_collection(self, collection_name: str):
        # Delete the collection based on the collection name.
        collection_name = collection_name.replace("-", "_")
        return self.client.drop_collection(
            collection_name=f"{self.collection_prefix}_{collection_name}"
        )

    def search(
        self, collection_name: str, vectors: list[list[float | int]], limit: int
    ) -> Optional[SearchResult]:
        # Search for the nearest neighbors of the given vectors and return up to
        # 'limit' results.
        collection_name = collection_name.replace("-", "_")
        # For some index types such as IVF_FLAT, search params like nprobe can be
        # set, e.g. search_params = {"nprobe": 10}. For simplicity, configurable
        # search_params are not added here, but this could be extended.
        result = self.client.search(
            collection_name=f"{self.collection_prefix}_{collection_name}",
            data=vectors,
            limit=limit,
            output_fields=["data", "metadata"],
            # search_params=search_params  # Potentially add later if needed.
        )
        return self._result_to_search_result(result)

    def query(self, collection_name: str, filter: dict, limit: Optional[int] = None):
        # Query items that match the given metadata filter.
        collection_name = collection_name.replace("-", "_")
        if not self.has_collection(collection_name):
            log.warning(
                f"Query attempted on non-existent collection: {self.collection_prefix}_{collection_name}"
            )
            return None

        # Construct the filter string for querying.
        filter_string = " && ".join(
            [
                f'metadata["{key}"] == {json.dumps(value)}'
                for key, value in filter.items()
            ]
        )

        max_limit = 16383  # The maximum number of records per request.
        all_results = []

        if limit is None:
            # Milvus' default query limit is 16384, but fetching everything requires
            # iteration. If limit is None, treat it as "as many as possible" and set
            # a large cap; each request is still bounded by max_limit per batch.
            # This could be refined for true "get all" pagination.
            limit = 16384 * 10
            log.info(
                f"Limit not specified for query, fetching up to {limit} results in batches."
            )

        # Initialize offset and remaining to handle pagination.
        offset = 0
        remaining = limit

        try:
            log.info(
                f"Querying collection {self.collection_prefix}_{collection_name} "
                f"with filter: '{filter_string}', limit: {limit}"
            )
            # Loop until there are no more items to fetch or the desired limit is reached.
            while remaining > 0:
                current_fetch = min(
                    max_limit, remaining if isinstance(remaining, int) else max_limit
                )
                log.debug(f"Querying with offset: {offset}, current_fetch: {current_fetch}")
                results = self.client.query(
                    collection_name=f"{self.collection_prefix}_{collection_name}",
                    filter=filter_string,
                    # Explicitly list the needed fields; the vector is usually not
                    # needed in a query.
                    output_fields=["id", "data", "metadata"],
                    limit=current_fetch,
                    offset=offset,
                )

                if not results:
                    log.debug("No more results from query.")
                    break

                all_results.extend(results)
                results_count = len(results)
                log.debug(f"Fetched {results_count} results in this batch.")

                if isinstance(remaining, int):
                    remaining -= results_count
                offset += results_count

                # Stop if fewer results were returned than requested; that means
                # the end of the data was reached.
                if results_count < current_fetch:
                    log.debug("Fetched less than requested, assuming end of results for this query.")
                    break

            log.info(f"Total results from query: {len(all_results)}")
            return self._result_to_get_result([all_results])
        except Exception as e:
            log.exception(
                f"Error querying collection {self.collection_prefix}_{collection_name} "
                f"with filter '{filter_string}' and limit {limit}: {e}"
            )
            return None

    def get(self, collection_name: str) -> Optional[GetResult]:
        # Get all items in the collection. This can be very resource-intensive for
        # large collections.
        collection_name = collection_name.replace("-", "_")
        log.warning(
            f"Fetching ALL items from collection '{self.collection_prefix}_{collection_name}'. "
            f"This might be slow for large collections."
        )
        # Use query with a trivial filter to get all items; this goes through the
        # paginated query logic above.
        return self.query(collection_name=collection_name, filter={}, limit=None)

    def insert(self, collection_name: str, items: list[VectorItem]):
        # Insert the items into the collection; if the collection does not exist,
        # it will be created.
        collection_name = collection_name.replace("-", "_")
        if not self.client.has_collection(
            collection_name=f"{self.collection_prefix}_{collection_name}"
        ):
            log.info(
                f"Collection {self.collection_prefix}_{collection_name} does not exist. Creating now."
            )
            if not items:
                log.error(
                    f"Cannot create collection {self.collection_prefix}_{collection_name} "
                    f"without items to determine dimension."
                )
                raise ValueError(
                    "Cannot create Milvus collection without items to determine vector dimension."
                )
            self._create_collection(
                collection_name=collection_name, dimension=len(items[0]["vector"])
            )

        log.info(
            f"Inserting {len(items)} items into collection {self.collection_prefix}_{collection_name}."
        )
        return self.client.insert(
            collection_name=f"{self.collection_prefix}_{collection_name}",
            data=[
                {
                    "id": item["id"],
                    "vector": item["vector"],
                    "data": {"text": item["text"]},
                    "metadata": item["metadata"],
                }
                for item in items
            ],
        )

    def upsert(self, collection_name: str, items: list[VectorItem]):
        # Update the items in the collection; if an item is not present, insert it.
        # If the collection does not exist, it will be created.
        collection_name = collection_name.replace("-", "_")
        if not self.client.has_collection(
            collection_name=f"{self.collection_prefix}_{collection_name}"
        ):
            log.info(
                f"Collection {self.collection_prefix}_{collection_name} does not exist for upsert. Creating now."
            )
            if not items:
                log.error(
                    f"Cannot create collection {self.collection_prefix}_{collection_name} "
                    f"for upsert without items to determine dimension."
                )
                raise ValueError(
                    "Cannot create Milvus collection for upsert without items to determine vector dimension."
                )
            self._create_collection(
                collection_name=collection_name, dimension=len(items[0]["vector"])
            )

        log.info(
            f"Upserting {len(items)} items into collection {self.collection_prefix}_{collection_name}."
        )
        return self.client.upsert(
            collection_name=f"{self.collection_prefix}_{collection_name}",
            data=[
                {
                    "id": item["id"],
                    "vector": item["vector"],
                    "data": {"text": item["text"]},
                    "metadata": item["metadata"],
                }
                for item in items
            ],
        )

    def delete(
        self,
        collection_name: str,
        ids: Optional[list[str]] = None,
        filter: Optional[dict] = None,
    ):
        # Delete items from the collection based on the ids or filter.
        collection_name = collection_name.replace("-", "_")
        if not self.has_collection(collection_name):
            log.warning(
                f"Delete attempted on non-existent collection: {self.collection_prefix}_{collection_name}"
            )
            return None

        if ids:
            log.info(
                f"Deleting items by IDs from {self.collection_prefix}_{collection_name}. IDs: {ids}"
            )
            return self.client.delete(
                collection_name=f"{self.collection_prefix}_{collection_name}",
                ids=ids,
            )
        elif filter:
            filter_string = " && ".join(
                [
                    f'metadata["{key}"] == {json.dumps(value)}'
                    for key, value in filter.items()
                ]
            )
            log.info(
                f"Deleting items by filter from {self.collection_prefix}_{collection_name}. "
                f"Filter: {filter_string}"
            )
            return self.client.delete(
                collection_name=f"{self.collection_prefix}_{collection_name}",
                filter=filter_string,
            )
        else:
            log.warning(
                f"Delete operation on {self.collection_prefix}_{collection_name} "
                f"called without IDs or filter. No action taken."
            )
            return None

    def reset(self):
        # Reset the database: delete all collections whose names match the prefix.
        log.warning(
            f"Resetting Milvus: Deleting all collections with prefix '{self.collection_prefix}'."
        )
        collection_names = self.client.list_collections()
        deleted_collections = []
        for collection_name_full in collection_names:
            if collection_name_full.startswith(self.collection_prefix):
                try:
                    self.client.drop_collection(collection_name=collection_name_full)
                    deleted_collections.append(collection_name_full)
                    log.info(f"Deleted collection: {collection_name_full}")
                except Exception as e:
                    log.error(f"Error deleting collection {collection_name_full}: {e}")
        log.info(f"Milvus reset complete. Deleted collections: {deleted_collections}")
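

# Minimal usage sketch (illustrative only, not part of the class above): it assumes
# a reachable Milvus instance configured via MILVUS_URI / MILVUS_DB / MILVUS_TOKEN,
# and uses a tiny 4-dimensional vector purely for brevity.
if __name__ == "__main__":
    vdb = MilvusClient()
    vdb.insert(
        collection_name="example",
        items=[
            {
                "id": "doc-1",
                "text": "hello world",
                "vector": [0.1, 0.2, 0.3, 0.4],
                "metadata": {"source": "example.txt"},
            }
        ],
    )
    print(vdb.search(collection_name="example", vectors=[[0.1, 0.2, 0.3, 0.4]], limit=1))
    vdb.delete(collection_name="example", ids=["doc-1"])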