s3vector.py 30 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683
  1. from open_webui.retrieval.vector.main import VectorDBBase, VectorItem, GetResult, SearchResult
  2. from open_webui.config import S3_VECTOR_BUCKET_NAME, S3_VECTOR_REGION
  3. from open_webui.env import SRC_LOG_LEVELS
  4. from typing import List, Optional, Dict, Any, Union
  5. import logging
  6. import boto3
  7. log = logging.getLogger(__name__)
  8. log.setLevel(SRC_LOG_LEVELS["RAG"])
  9. class S3VectorClient(VectorDBBase):
  10. """
  11. AWS S3 Vector integration for Open WebUI Knowledge.
  12. Assumes AWS credentials are available via environment variables or IAM roles.
  13. """
  14. def __init__(self):
  15. self.bucket_name = S3_VECTOR_BUCKET_NAME
  16. self.region = S3_VECTOR_REGION
  17. self.client = boto3.client("s3vectors", region_name=self.region)
  18. def _create_index(self, index_name: str, dimension: int, data_type: str = "float32", distance_metric: str = "cosine"):
  19. """
  20. Create a new index in the S3 vector bucket for the given collection if it does not exist.
  21. """
  22. if self.has_collection(index_name):
  23. return
  24. try:
  25. self.client.create_index(
  26. vectorBucketName=self.bucket_name,
  27. indexName=index_name,
  28. dataType=data_type,
  29. dimension=dimension,
  30. distanceMetric=distance_metric,
  31. )
  32. log.info(f"Created S3 index: {index_name} (dim={dimension}, type={data_type}, metric={distance_metric})")
  33. except Exception as e:
  34. log.error(f"Error creating S3 index '{index_name}': {e}")
  35. def _filter_metadata(self, metadata: Dict[str, Any], item_id: str) -> Dict[str, Any]:
  36. """
  37. Filter vector metadata keys to comply with S3 Vector API limit of 10 keys maximum.
  38. If AWS S3 Vector feature starts supporting more than 10 keys, this should be adjusted, and preferably removed.
  39. Limitation is documented here: https://docs.aws.amazon.com/AmazonS3/latest/userguide/s3-vectors-limitations.html
  40. """
  41. if not isinstance(metadata, dict) or len(metadata) <= 10:
  42. return metadata
  43. # Keep only the first 10 keys, prioritizing important ones based on actual Open WebUI metadata
  44. important_keys = [
  45. 'text', # The actual document content
  46. 'file_id', # File ID
  47. 'source', # Document source file
  48. 'title', # Document title
  49. 'page', # Page number
  50. 'total_pages', # Total pages in document
  51. 'embedding_config', # Embedding configuration
  52. 'created_by', # User who created it
  53. 'name', # Document name
  54. 'hash', # Content hash
  55. ]
  56. filtered_metadata = {}
  57. # First, add important keys if they exist
  58. for key in important_keys:
  59. if key in metadata:
  60. filtered_metadata[key] = metadata[key]
  61. if len(filtered_metadata) >= 10:
  62. break
  63. # If we still have room, add other keys
  64. if len(filtered_metadata) < 10:
  65. for key, value in metadata.items():
  66. if key not in filtered_metadata:
  67. filtered_metadata[key] = value
  68. if len(filtered_metadata) >= 10:
  69. break
  70. log.warning(f"Metadata for key '{item_id}' had {len(metadata)} keys, limited to 10 keys")
  71. return filtered_metadata
  72. def has_collection(self, collection_name: str) -> bool:
  73. """
  74. Check if a vector index (collection) exists in the S3 vector bucket.
  75. """
  76. try:
  77. response = self.client.list_indexes(vectorBucketName=self.bucket_name)
  78. indexes = response.get("indexes", [])
  79. return any(idx.get("indexName") == collection_name for idx in indexes)
  80. except Exception as e:
  81. log.error(f"Error listing indexes: {e}")
  82. return False
  83. def delete_collection(self, collection_name: str) -> None:
  84. """
  85. Delete an entire S3 Vector index/collection.
  86. """
  87. if not self.has_collection(collection_name):
  88. log.warning(f"Collection '{collection_name}' does not exist, nothing to delete")
  89. return
  90. try:
  91. log.info(f"Deleting collection '{collection_name}'")
  92. self.client.delete_index(
  93. vectorBucketName=self.bucket_name,
  94. indexName=collection_name
  95. )
  96. log.info(f"Successfully deleted collection '{collection_name}'")
  97. except Exception as e:
  98. log.error(f"Error deleting collection '{collection_name}': {e}")
  99. raise
  100. def insert(self, collection_name: str, items: List[Dict[str, Any]]) -> None:
  101. """
  102. Insert vector items into the S3 Vector index. Create index if it does not exist.
  103. Supports both knowledge collection indexes and file-specific indexes (file-{file_id}).
  104. """
  105. if not items:
  106. log.warning("No items to insert")
  107. return
  108. dimension = len(items[0]["vector"])
  109. try:
  110. if not self.has_collection(collection_name):
  111. log.info(f"Index '{collection_name}' does not exist. Creating index.")
  112. self._create_index(
  113. index_name=collection_name,
  114. dimension=dimension,
  115. data_type="float32",
  116. distance_metric="cosine",
  117. )
  118. # Prepare vectors for insertion
  119. vectors = []
  120. for item in items:
  121. # Ensure vector data is in the correct format for S3 Vector API
  122. vector_data = item["vector"]
  123. if isinstance(vector_data, list):
  124. # Convert list to float32 values as required by S3 Vector API
  125. vector_data = [float(x) for x in vector_data]
  126. # Prepare metadata, ensuring the text field is preserved
  127. metadata = item.get("metadata", {}).copy()
  128. # Add the text field to metadata so it's available for retrieval
  129. if "text" in item:
  130. metadata["text"] = item["text"]
  131. else:
  132. log.warning(f"No 'text' field found in item with ID: {item.get('id')}")
  133. # Filter metadata to comply with S3 Vector API limit of 10 keys
  134. metadata = self._filter_metadata(metadata, item["id"])
  135. vectors.append({
  136. "key": item["id"],
  137. "data": {
  138. "float32": vector_data
  139. },
  140. "metadata": metadata
  141. })
  142. # Insert vectors
  143. self.client.put_vectors(
  144. vectorBucketName=self.bucket_name,
  145. indexName=collection_name,
  146. vectors=vectors
  147. )
  148. log.info(f"Inserted {len(vectors)} vectors into index '{collection_name}'.")
  149. except Exception as e:
  150. log.error(f"Error inserting vectors: {e}")
  151. raise
  152. def upsert(self, collection_name: str, items: List[Dict[str, Any]]) -> None:
  153. """
  154. Insert or update vector items in the S3 Vector index. Create index if it does not exist.
  155. """
  156. if not items:
  157. log.warning("No items to upsert")
  158. return
  159. dimension = len(items[0]["vector"])
  160. log.info(f"Upsert dimension: {dimension}")
  161. try:
  162. if not self.has_collection(collection_name):
  163. log.info(f"Index '{collection_name}' does not exist. Creating index for upsert.")
  164. self._create_index(
  165. index_name=collection_name,
  166. dimension=dimension,
  167. data_type="float32",
  168. distance_metric="cosine",
  169. )
  170. # Prepare vectors for upsert
  171. vectors = []
  172. for item in items:
  173. # Ensure vector data is in the correct format for S3 Vector API
  174. vector_data = item["vector"]
  175. if isinstance(vector_data, list):
  176. # Convert list to float32 values as required by S3 Vector API
  177. vector_data = [float(x) for x in vector_data]
  178. # Prepare metadata, ensuring the text field is preserved
  179. metadata = item.get("metadata", {}).copy()
  180. # Add the text field to metadata so it's available for retrieval
  181. if "text" in item:
  182. metadata["text"] = item["text"]
  183. # Filter metadata to comply with S3 Vector API limit of 10 keys
  184. metadata = self._filter_metadata(metadata, item["id"])
  185. vectors.append({
  186. "key": item["id"],
  187. "data": {
  188. "float32": vector_data
  189. },
  190. "metadata": metadata
  191. })
  192. # Upsert vectors (using put_vectors for upsert semantics)
  193. log.info(f"Upserting {len(vectors)} vectors. First vector sample: key={vectors[0]['key']}, data_type={type(vectors[0]['data']['float32'])}, data_len={len(vectors[0]['data']['float32'])}")
  194. self.client.put_vectors(
  195. vectorBucketName=self.bucket_name,
  196. indexName=collection_name,
  197. vectors=vectors
  198. )
  199. log.info(f"Upserted {len(vectors)} vectors into index '{collection_name}'.")
  200. except Exception as e:
  201. log.error(f"Error upserting vectors: {e}")
  202. raise
  203. def search(self, collection_name: str, vectors: List[List[Union[float, int]]], limit: int) -> Optional[SearchResult]:
  204. """
  205. Search for similar vectors in a collection using multiple query vectors.
  206. Uses S3 Vector's query_vectors API to perform similarity search.
  207. Args:
  208. collection_name: Name of the collection to search in
  209. vectors: List of query vectors to search with
  210. limit: Maximum number of results to return per query
  211. Returns:
  212. SearchResult containing IDs, documents, metadatas, and distances
  213. """
  214. if not self.has_collection(collection_name):
  215. log.warning(f"Collection '{collection_name}' does not exist")
  216. return None
  217. if not vectors:
  218. log.warning("No query vectors provided")
  219. return None
  220. try:
  221. log.info(f"Searching collection '{collection_name}' with {len(vectors)} query vectors, limit={limit}")
  222. # Initialize result lists
  223. all_ids = []
  224. all_documents = []
  225. all_metadatas = []
  226. all_distances = []
  227. # Process each query vector
  228. for i, query_vector in enumerate(vectors):
  229. log.debug(f"Processing query vector {i+1}/{len(vectors)}")
  230. # Prepare the query vector in S3 Vector format
  231. query_vector_dict = {
  232. 'float32': [float(x) for x in query_vector]
  233. }
  234. # Call S3 Vector query API
  235. response = self.client.query_vectors(
  236. vectorBucketName=self.bucket_name,
  237. indexName=collection_name,
  238. topK=limit,
  239. queryVector=query_vector_dict,
  240. returnMetadata=True,
  241. returnDistance=True
  242. )
  243. # Process results for this query
  244. query_ids = []
  245. query_documents = []
  246. query_metadatas = []
  247. query_distances = []
  248. result_vectors = response.get('vectors', [])
  249. for vector in result_vectors:
  250. vector_id = vector.get('key')
  251. vector_metadata = vector.get('metadata', {})
  252. vector_distance = vector.get('distance', 0.0)
  253. # Extract document text from metadata
  254. document_text = ""
  255. if isinstance(vector_metadata, dict):
  256. # Get the text field first (highest priority)
  257. document_text = vector_metadata.get('text')
  258. if not document_text:
  259. # Fallback to other possible text fields
  260. document_text = (vector_metadata.get('content') or
  261. vector_metadata.get('document') or
  262. vector_id)
  263. else:
  264. document_text = vector_id
  265. query_ids.append(vector_id)
  266. query_documents.append(document_text)
  267. query_metadatas.append(vector_metadata)
  268. query_distances.append(vector_distance)
  269. # Add this query's results to the overall results
  270. all_ids.append(query_ids)
  271. all_documents.append(query_documents)
  272. all_metadatas.append(query_metadatas)
  273. all_distances.append(query_distances)
  274. log.info(f"Search completed. Found results for {len(all_ids)} queries")
  275. # Return SearchResult format
  276. return SearchResult(
  277. ids=all_ids if all_ids else None,
  278. documents=all_documents if all_documents else None,
  279. metadatas=all_metadatas if all_metadatas else None,
  280. distances=all_distances if all_distances else None
  281. )
  282. except Exception as e:
  283. log.error(f"Error searching collection '{collection_name}': {str(e)}")
  284. # Handle specific AWS exceptions
  285. if hasattr(e, 'response') and 'Error' in e.response:
  286. error_code = e.response['Error']['Code']
  287. if error_code == 'NotFoundException':
  288. log.warning(f"Collection '{collection_name}' not found")
  289. return None
  290. elif error_code == 'ValidationException':
  291. log.error(f"Invalid query vector dimensions or parameters")
  292. return None
  293. elif error_code == 'AccessDeniedException':
  294. log.error(f"Access denied for collection '{collection_name}'. Check permissions.")
  295. return None
  296. raise
  297. def query(self, collection_name: str, filter: Dict, limit: Optional[int] = None) -> Optional[GetResult]:
  298. """
  299. Query vectors from a collection using metadata filter.
  300. For S3 Vector, this uses the list_vectors API with metadata filters.
  301. Note: S3 Vector supports metadata filtering, but the exact filter syntax may vary.
  302. Args:
  303. collection_name: Name of the collection to query
  304. filter: Dictionary containing metadata filter conditions
  305. limit: Maximum number of results to return (optional)
  306. Returns:
  307. GetResult containing IDs, documents, and metadatas
  308. """
  309. if not self.has_collection(collection_name):
  310. log.warning(f"Collection '{collection_name}' does not exist")
  311. return GetResult(ids=[[]], documents=[[]], metadatas=[[]])
  312. if not filter:
  313. log.warning("No filter provided, returning all vectors")
  314. return self.get(collection_name)
  315. try:
  316. log.info(f"Querying collection '{collection_name}' with filter: {filter}")
  317. # For S3 Vector, we need to use list_vectors and then filter results
  318. # Since S3 Vector may not support complex server-side filtering,
  319. # we'll retrieve all vectors and filter client-side
  320. # Get all vectors first
  321. all_vectors_result = self.get(collection_name)
  322. if not all_vectors_result or not all_vectors_result.ids:
  323. log.warning("No vectors found in collection")
  324. return GetResult(ids=[[]], documents=[[]], metadatas=[[]])
  325. # Extract the lists from the result
  326. all_ids = all_vectors_result.ids[0] if all_vectors_result.ids else []
  327. all_documents = all_vectors_result.documents[0] if all_vectors_result.documents else []
  328. all_metadatas = all_vectors_result.metadatas[0] if all_vectors_result.metadatas else []
  329. # Apply client-side filtering
  330. filtered_ids = []
  331. filtered_documents = []
  332. filtered_metadatas = []
  333. for i, metadata in enumerate(all_metadatas):
  334. if self._matches_filter(metadata, filter):
  335. if i < len(all_ids):
  336. filtered_ids.append(all_ids[i])
  337. if i < len(all_documents):
  338. filtered_documents.append(all_documents[i])
  339. filtered_metadatas.append(metadata)
  340. # Apply limit if specified
  341. if limit and len(filtered_ids) >= limit:
  342. break
  343. log.info(f"Filter applied: {len(filtered_ids)} vectors match out of {len(all_ids)} total")
  344. # Return GetResult format
  345. if filtered_ids:
  346. return GetResult(ids=[filtered_ids], documents=[filtered_documents], metadatas=[filtered_metadatas])
  347. else:
  348. return GetResult(ids=[[]], documents=[[]], metadatas=[[]])
  349. except Exception as e:
  350. log.error(f"Error querying collection '{collection_name}': {str(e)}")
  351. # Handle specific AWS exceptions
  352. if hasattr(e, 'response') and 'Error' in e.response:
  353. error_code = e.response['Error']['Code']
  354. if error_code == 'NotFoundException':
  355. log.warning(f"Collection '{collection_name}' not found")
  356. return GetResult(ids=[[]], documents=[[]], metadatas=[[]])
  357. elif error_code == 'AccessDeniedException':
  358. log.error(f"Access denied for collection '{collection_name}'. Check permissions.")
  359. return GetResult(ids=[[]], documents=[[]], metadatas=[[]])
  360. raise
  361. def get(self, collection_name: str) -> Optional[GetResult]:
  362. """
  363. Retrieve all vectors from a collection.
  364. Uses S3 Vector's list_vectors API to get all vectors with their data and metadata.
  365. Handles pagination automatically to retrieve all vectors.
  366. """
  367. if not self.has_collection(collection_name):
  368. log.warning(f"Collection '{collection_name}' does not exist")
  369. return GetResult(ids=[[]], documents=[[]], metadatas=[[]])
  370. try:
  371. log.info(f"Retrieving all vectors from collection '{collection_name}'")
  372. # Initialize result lists
  373. all_ids = []
  374. all_documents = []
  375. all_metadatas = []
  376. # Handle pagination
  377. next_token = None
  378. while True:
  379. # Prepare request parameters
  380. request_params = {
  381. 'vectorBucketName': self.bucket_name,
  382. 'indexName': collection_name,
  383. 'returnData': False, # Don't include vector data (not needed for get)
  384. 'returnMetadata': True, # Include metadata
  385. 'maxResults': 500 # Use reasonable page size
  386. }
  387. if next_token:
  388. request_params['nextToken'] = next_token
  389. # Call S3 Vector API
  390. response = self.client.list_vectors(**request_params)
  391. # Process vectors in this page
  392. vectors = response.get('vectors', [])
  393. for vector in vectors:
  394. vector_id = vector.get('key')
  395. vector_data = vector.get('data', {})
  396. vector_metadata = vector.get('metadata', {})
  397. # Extract the actual vector array
  398. vector_array = vector_data.get('float32', [])
  399. # For documents, we try to extract text from metadata or use the vector ID
  400. document_text = ""
  401. if isinstance(vector_metadata, dict):
  402. # Get the text field first (highest priority)
  403. document_text = vector_metadata.get('text')
  404. if not document_text:
  405. # Fallback to other possible text fields
  406. document_text = (vector_metadata.get('content') or
  407. vector_metadata.get('document') or
  408. vector_id)
  409. # Log the actual content for debugging
  410. log.debug(f"Document text preview (first 200 chars): {str(document_text)[:200]}")
  411. else:
  412. document_text = vector_id
  413. all_ids.append(vector_id)
  414. all_documents.append(document_text)
  415. all_metadatas.append(vector_metadata)
  416. # Check if there are more pages
  417. next_token = response.get('nextToken')
  418. if not next_token:
  419. break
  420. log.info(f"Retrieved {len(all_ids)} vectors from collection '{collection_name}'")
  421. # Return in GetResult format
  422. # The Open WebUI GetResult expects lists of lists, so we wrap each list
  423. if all_ids:
  424. return GetResult(ids=[all_ids], documents=[all_documents], metadatas=[all_metadatas])
  425. else:
  426. return GetResult(ids=[[]], documents=[[]], metadatas=[[]])
  427. except Exception as e:
  428. log.error(f"Error retrieving vectors from collection '{collection_name}': {str(e)}")
  429. # Handle specific AWS exceptions
  430. if hasattr(e, 'response') and 'Error' in e.response:
  431. error_code = e.response['Error']['Code']
  432. if error_code == 'NotFoundException':
  433. log.warning(f"Collection '{collection_name}' not found")
  434. return GetResult(ids=[[]], documents=[[]], metadatas=[[]])
  435. elif error_code == 'AccessDeniedException':
  436. log.error(f"Access denied for collection '{collection_name}'. Check permissions.")
  437. return GetResult(ids=[[]], documents=[[]], metadatas=[[]])
  438. raise
  439. def delete(self, collection_name: str, ids: Optional[List[str]] = None, filter: Optional[Dict] = None) -> None:
  440. """
  441. Delete vectors by ID or filter from a collection.
  442. For S3 Vector, we support deletion by IDs. Filter-based deletion requires querying first.
  443. For knowledge collections, also handles cleanup of related file-specific collections.
  444. """
  445. if not self.has_collection(collection_name):
  446. log.warning(f"Collection '{collection_name}' does not exist, nothing to delete")
  447. return
  448. # Check if this is a knowledge collection (not file-specific)
  449. is_knowledge_collection = not collection_name.startswith("file-")
  450. try:
  451. if ids:
  452. # Delete by specific vector IDs/keys
  453. log.info(f"Deleting {len(ids)} vectors by IDs from collection '{collection_name}'")
  454. self.client.delete_vectors(
  455. vectorBucketName=self.bucket_name,
  456. indexName=collection_name,
  457. keys=ids
  458. )
  459. log.info(f"Deleted {len(ids)} vectors from index '{collection_name}'")
  460. elif filter:
  461. # Handle filter-based deletion
  462. log.info(f"Deleting vectors by filter from collection '{collection_name}': {filter}")
  463. # If this is a knowledge collection and we have a file_id filter,
  464. # also clean up the corresponding file-specific collection
  465. if is_knowledge_collection and "file_id" in filter:
  466. file_id = filter["file_id"]
  467. file_collection_name = f"file-{file_id}"
  468. if self.has_collection(file_collection_name):
  469. log.info(f"Found related file-specific collection '{file_collection_name}', deleting it to prevent duplicates")
  470. self.delete_collection(file_collection_name)
  471. # For the main collection, implement query-then-delete
  472. # First, query to get IDs matching the filter
  473. query_result = self.query(collection_name, filter)
  474. if query_result and query_result.ids and query_result.ids[0]:
  475. matching_ids = query_result.ids[0]
  476. log.info(f"Found {len(matching_ids)} vectors matching filter, deleting them")
  477. # Delete the matching vectors by ID
  478. self.client.delete_vectors(
  479. vectorBucketName=self.bucket_name,
  480. indexName=collection_name,
  481. keys=matching_ids
  482. )
  483. log.info(f"Deleted {len(matching_ids)} vectors from index '{collection_name}' using filter")
  484. else:
  485. log.warning("No vectors found matching the filter criteria")
  486. else:
  487. log.warning("No IDs or filter provided for deletion")
  488. except Exception as e:
  489. log.error(f"Error deleting vectors from collection '{collection_name}': {e}")
  490. raise
  491. def reset(self) -> None:
  492. """
  493. Reset/clear all vector data. For S3 Vector, this would mean deleting all indexes.
  494. Use with caution as this is destructive.
  495. """
  496. try:
  497. log.warning("Reset called - this will delete all vector indexes in the S3 bucket")
  498. # List all indexes
  499. response = self.client.list_indexes(vectorBucketName=self.bucket_name)
  500. indexes = response.get("indexes", [])
  501. if not indexes:
  502. log.warning("No indexes found to delete")
  503. return
  504. # Delete all indexes
  505. deleted_count = 0
  506. for index in indexes:
  507. index_name = index.get("indexName")
  508. if index_name:
  509. try:
  510. self.client.delete_index(
  511. vectorBucketName=self.bucket_name,
  512. indexName=index_name
  513. )
  514. deleted_count += 1
  515. log.info(f"Deleted index: {index_name}")
  516. except Exception as e:
  517. log.error(f"Error deleting index '{index_name}': {e}")
  518. log.info(f"Reset completed: deleted {deleted_count} indexes")
  519. except Exception as e:
  520. log.error(f"Error during reset: {e}")
  521. raise
  522. def _matches_filter(self, metadata: Dict[str, Any], filter: Dict[str, Any]) -> bool:
  523. """
  524. Check if metadata matches the given filter conditions.
  525. Supports basic equality matching and simple logical operations.
  526. Args:
  527. metadata: The metadata to check
  528. filter: The filter conditions to match against
  529. Returns:
  530. True if metadata matches all filter conditions, False otherwise
  531. """
  532. if not isinstance(metadata, dict) or not isinstance(filter, dict):
  533. return False
  534. # Check each filter condition
  535. for key, expected_value in filter.items():
  536. # Handle special operators
  537. if key.startswith('$'):
  538. if key == '$and':
  539. # All conditions must match
  540. if not isinstance(expected_value, list):
  541. continue
  542. for condition in expected_value:
  543. if not self._matches_filter(metadata, condition):
  544. return False
  545. elif key == '$or':
  546. # At least one condition must match
  547. if not isinstance(expected_value, list):
  548. continue
  549. any_match = False
  550. for condition in expected_value:
  551. if self._matches_filter(metadata, condition):
  552. any_match = True
  553. break
  554. if not any_match:
  555. return False
  556. continue
  557. # Get the actual value from metadata
  558. actual_value = metadata.get(key)
  559. # Handle different types of expected values
  560. if isinstance(expected_value, dict):
  561. # Handle comparison operators
  562. for op, op_value in expected_value.items():
  563. if op == '$eq':
  564. if actual_value != op_value:
  565. return False
  566. elif op == '$ne':
  567. if actual_value == op_value:
  568. return False
  569. elif op == '$in':
  570. if not isinstance(op_value, list) or actual_value not in op_value:
  571. return False
  572. elif op == '$nin':
  573. if isinstance(op_value, list) and actual_value in op_value:
  574. return False
  575. elif op == '$exists':
  576. if bool(op_value) != (key in metadata):
  577. return False
  578. # Add more operators as needed
  579. else:
  580. # Simple equality check
  581. if actual_value != expected_value:
  582. return False
  583. return True