# s3vector.py
  1. from open_webui.retrieval.vector.main import (
  2. VectorDBBase,
  3. VectorItem,
  4. GetResult,
  5. SearchResult,
  6. )
  7. from open_webui.config import S3_VECTOR_BUCKET_NAME, S3_VECTOR_REGION
  8. from open_webui.env import SRC_LOG_LEVELS
  9. from typing import List, Optional, Dict, Any, Union
  10. import logging
  11. import boto3
  12. log = logging.getLogger(__name__)
  13. log.setLevel(SRC_LOG_LEVELS["RAG"])
  14. class S3VectorClient(VectorDBBase):
  15. """
  16. AWS S3 Vector integration for Open WebUI Knowledge.
  17. """
  18. def __init__(self):
  19. self.bucket_name = S3_VECTOR_BUCKET_NAME
  20. self.region = S3_VECTOR_REGION
  21. # Simple validation - log warnings instead of raising exceptions
  22. if not self.bucket_name:
  23. log.warning("S3_VECTOR_BUCKET_NAME not set - S3Vector will not work")
  24. if not self.region:
  25. log.warning("S3_VECTOR_REGION not set - S3Vector will not work")
  26. if self.bucket_name and self.region:
  27. try:
  28. self.client = boto3.client("s3vectors", region_name=self.region)
  29. log.info(
  30. f"S3Vector client initialized for bucket '{self.bucket_name}' in region '{self.region}'"
  31. )
  32. except Exception as e:
  33. log.error(f"Failed to initialize S3Vector client: {e}")
  34. self.client = None
  35. else:
  36. self.client = None
  37. def _create_index(
  38. self,
  39. index_name: str,
  40. dimension: int,
  41. data_type: str = "float32",
  42. distance_metric: str = "cosine",
  43. ) -> None:
  44. """
  45. Create a new index in the S3 vector bucket for the given collection if it does not exist.
  46. """
  47. if self.has_collection(index_name):
  48. log.debug(f"Index '{index_name}' already exists, skipping creation")
  49. return
  50. try:
  51. self.client.create_index(
  52. vectorBucketName=self.bucket_name,
  53. indexName=index_name,
  54. dataType=data_type,
  55. dimension=dimension,
  56. distanceMetric=distance_metric,
  57. )
  58. log.info(
  59. f"Created S3 index: {index_name} (dim={dimension}, type={data_type}, metric={distance_metric})"
  60. )
  61. except Exception as e:
  62. log.error(f"Error creating S3 index '{index_name}': {e}")
  63. raise
  64. def _filter_metadata(
  65. self, metadata: Dict[str, Any], item_id: str
  66. ) -> Dict[str, Any]:
  67. """
  68. Filter vector metadata keys to comply with S3 Vector API limit of 10 keys maximum.
  69. """
  70. if not isinstance(metadata, dict) or len(metadata) <= 10:
  71. return metadata
  72. # Keep only the first 10 keys, prioritizing important ones based on actual Open WebUI metadata
  73. important_keys = [
  74. "text", # The actual document content
  75. "file_id", # File ID
  76. "source", # Document source file
  77. "title", # Document title
  78. "page", # Page number
  79. "total_pages", # Total pages in document
  80. "embedding_config", # Embedding configuration
  81. "created_by", # User who created it
  82. "name", # Document name
  83. "hash", # Content hash
  84. ]
  85. filtered_metadata = {}
  86. # First, add important keys if they exist
  87. for key in important_keys:
  88. if key in metadata:
  89. filtered_metadata[key] = metadata[key]
  90. if len(filtered_metadata) >= 10:
  91. break
  92. # If we still have room, add other keys
  93. if len(filtered_metadata) < 10:
  94. for key, value in metadata.items():
  95. if key not in filtered_metadata:
  96. filtered_metadata[key] = value
  97. if len(filtered_metadata) >= 10:
  98. break
  99. log.warning(
  100. f"Metadata for key '{item_id}' had {len(metadata)} keys, limited to 10 keys"
  101. )
  102. return filtered_metadata
  103. def has_collection(self, collection_name: str) -> bool:
  104. """
  105. Check if a vector index (collection) exists in the S3 vector bucket.
  106. """
  107. try:
  108. response = self.client.list_indexes(vectorBucketName=self.bucket_name)
  109. indexes = response.get("indexes", [])
  110. return any(idx.get("indexName") == collection_name for idx in indexes)
  111. except Exception as e:
  112. log.error(f"Error listing indexes: {e}")
  113. return False
  114. def delete_collection(self, collection_name: str) -> None:
  115. """
  116. Delete an entire S3 Vector index/collection.
  117. """
  118. if not self.has_collection(collection_name):
  119. log.warning(
  120. f"Collection '{collection_name}' does not exist, nothing to delete"
  121. )
  122. return
  123. try:
  124. log.info(f"Deleting collection '{collection_name}'")
  125. self.client.delete_index(
  126. vectorBucketName=self.bucket_name, indexName=collection_name
  127. )
  128. log.info(f"Successfully deleted collection '{collection_name}'")
  129. except Exception as e:
  130. log.error(f"Error deleting collection '{collection_name}': {e}")
  131. raise
  132. def insert(self, collection_name: str, items: List[VectorItem]) -> None:
  133. """
  134. Insert vector items into the S3 Vector index. Create index if it does not exist.
  135. """
  136. if not items:
  137. log.warning("No items to insert")
  138. return
  139. dimension = len(items[0]["vector"])
  140. try:
  141. if not self.has_collection(collection_name):
  142. log.info(f"Index '{collection_name}' does not exist. Creating index.")
  143. self._create_index(
  144. index_name=collection_name,
  145. dimension=dimension,
  146. data_type="float32",
  147. distance_metric="cosine",
  148. )
  149. # Prepare vectors for insertion
  150. vectors = []
  151. for item in items:
  152. # Ensure vector data is in the correct format for S3 Vector API
  153. vector_data = item["vector"]
  154. if isinstance(vector_data, list):
  155. # Convert list to float32 values as required by S3 Vector API
  156. vector_data = [float(x) for x in vector_data]
  157. # Prepare metadata, ensuring the text field is preserved
  158. metadata = item.get("metadata", {}).copy()
  159. # Add the text field to metadata so it's available for retrieval
  160. metadata["text"] = item["text"]
  161. # Filter metadata to comply with S3 Vector API limit of 10 keys
  162. metadata = self._filter_metadata(metadata, item["id"])
  163. vectors.append(
  164. {
  165. "key": item["id"],
  166. "data": {"float32": vector_data},
  167. "metadata": metadata,
  168. }
  169. )
  170. # Insert vectors
  171. self.client.put_vectors(
  172. vectorBucketName=self.bucket_name,
  173. indexName=collection_name,
  174. vectors=vectors,
  175. )
  176. log.info(f"Inserted {len(vectors)} vectors into index '{collection_name}'.")
  177. except Exception as e:
  178. log.error(f"Error inserting vectors: {e}")
  179. raise
  180. def upsert(self, collection_name: str, items: List[VectorItem]) -> None:
  181. """
  182. Insert or update vector items in the S3 Vector index. Create index if it does not exist.
  183. """
  184. if not items:
  185. log.warning("No items to upsert")
  186. return
  187. dimension = len(items[0]["vector"])
  188. log.info(f"Upsert dimension: {dimension}")
  189. try:
  190. if not self.has_collection(collection_name):
  191. log.info(
  192. f"Index '{collection_name}' does not exist. Creating index for upsert."
  193. )
  194. self._create_index(
  195. index_name=collection_name,
  196. dimension=dimension,
  197. data_type="float32",
  198. distance_metric="cosine",
  199. )
  200. # Prepare vectors for upsert
  201. vectors = []
  202. for item in items:
  203. # Ensure vector data is in the correct format for S3 Vector API
  204. vector_data = item["vector"]
  205. if isinstance(vector_data, list):
  206. # Convert list to float32 values as required by S3 Vector API
  207. vector_data = [float(x) for x in vector_data]
  208. # Prepare metadata, ensuring the text field is preserved
  209. metadata = item.get("metadata", {}).copy()
  210. # Add the text field to metadata so it's available for retrieval
  211. metadata["text"] = item["text"]
  212. # Filter metadata to comply with S3 Vector API limit of 10 keys
  213. metadata = self._filter_metadata(metadata, item["id"])
  214. vectors.append(
  215. {
  216. "key": item["id"],
  217. "data": {"float32": vector_data},
  218. "metadata": metadata,
  219. }
  220. )
  221. # Upsert vectors (using put_vectors for upsert semantics)
  222. log.info(
  223. f"Upserting {len(vectors)} vectors. First vector sample: key={vectors[0]['key']}, data_type={type(vectors[0]['data']['float32'])}, data_len={len(vectors[0]['data']['float32'])}"
  224. )
  225. self.client.put_vectors(
  226. vectorBucketName=self.bucket_name,
  227. indexName=collection_name,
  228. vectors=vectors,
  229. )
  230. log.info(f"Upserted {len(vectors)} vectors into index '{collection_name}'.")
  231. except Exception as e:
  232. log.error(f"Error upserting vectors: {e}")
  233. raise
  234. def search(
  235. self, collection_name: str, vectors: List[List[Union[float, int]]], limit: int
  236. ) -> Optional[SearchResult]:
  237. """
  238. Search for similar vectors in a collection using multiple query vectors.
  239. """
  240. if not self.has_collection(collection_name):
  241. log.warning(f"Collection '{collection_name}' does not exist")
  242. return None
  243. if not vectors:
  244. log.warning("No query vectors provided")
  245. return None
  246. try:
  247. log.info(
  248. f"Searching collection '{collection_name}' with {len(vectors)} query vectors, limit={limit}"
  249. )
  250. # Initialize result lists
  251. all_ids = []
  252. all_documents = []
  253. all_metadatas = []
  254. all_distances = []
  255. # Process each query vector
  256. for i, query_vector in enumerate(vectors):
  257. log.debug(f"Processing query vector {i+1}/{len(vectors)}")
  258. # Prepare the query vector in S3 Vector format
  259. query_vector_dict = {"float32": [float(x) for x in query_vector]}
  260. # Call S3 Vector query API
  261. response = self.client.query_vectors(
  262. vectorBucketName=self.bucket_name,
  263. indexName=collection_name,
  264. topK=limit,
  265. queryVector=query_vector_dict,
  266. returnMetadata=True,
  267. returnDistance=True,
  268. )
  269. # Process results for this query
  270. query_ids = []
  271. query_documents = []
  272. query_metadatas = []
  273. query_distances = []
  274. result_vectors = response.get("vectors", [])
  275. for vector in result_vectors:
  276. vector_id = vector.get("key")
  277. vector_metadata = vector.get("metadata", {})
  278. vector_distance = vector.get("distance", 0.0)
  279. # Extract document text from metadata
  280. document_text = ""
  281. if isinstance(vector_metadata, dict):
  282. # Get the text field first (highest priority)
  283. document_text = vector_metadata.get("text")
  284. if not document_text:
  285. # Fallback to other possible text fields
  286. document_text = (
  287. vector_metadata.get("content")
  288. or vector_metadata.get("document")
  289. or vector_id
  290. )
  291. else:
  292. document_text = vector_id
  293. query_ids.append(vector_id)
  294. query_documents.append(document_text)
  295. query_metadatas.append(vector_metadata)
  296. query_distances.append(vector_distance)
  297. # Add this query's results to the overall results
  298. all_ids.append(query_ids)
  299. all_documents.append(query_documents)
  300. all_metadatas.append(query_metadatas)
  301. all_distances.append(query_distances)
  302. log.info(f"Search completed. Found results for {len(all_ids)} queries")
  303. # Return SearchResult format
  304. return SearchResult(
  305. ids=all_ids if all_ids else None,
  306. documents=all_documents if all_documents else None,
  307. metadatas=all_metadatas if all_metadatas else None,
  308. distances=all_distances if all_distances else None,
  309. )
  310. except Exception as e:
  311. log.error(f"Error searching collection '{collection_name}': {str(e)}")
  312. # Handle specific AWS exceptions
  313. if hasattr(e, "response") and "Error" in e.response:
  314. error_code = e.response["Error"]["Code"]
  315. if error_code == "NotFoundException":
  316. log.warning(f"Collection '{collection_name}' not found")
  317. return None
  318. elif error_code == "ValidationException":
  319. log.error(f"Invalid query vector dimensions or parameters")
  320. return None
  321. elif error_code == "AccessDeniedException":
  322. log.error(
  323. f"Access denied for collection '{collection_name}'. Check permissions."
  324. )
  325. return None
  326. raise
  327. def query(
  328. self, collection_name: str, filter: Dict, limit: Optional[int] = None
  329. ) -> Optional[GetResult]:
  330. """
  331. Query vectors from a collection using metadata filter.
  332. """
  333. if not self.has_collection(collection_name):
  334. log.warning(f"Collection '{collection_name}' does not exist")
  335. return GetResult(ids=[[]], documents=[[]], metadatas=[[]])
  336. if not filter:
  337. log.warning("No filter provided, returning all vectors")
  338. return self.get(collection_name)
  339. try:
  340. log.info(f"Querying collection '{collection_name}' with filter: {filter}")
  341. # For S3 Vector, we need to use list_vectors and then filter results
  342. # Since S3 Vector may not support complex server-side filtering,
  343. # we'll retrieve all vectors and filter client-side
  344. # Get all vectors first
  345. all_vectors_result = self.get(collection_name)
  346. if not all_vectors_result or not all_vectors_result.ids:
  347. log.warning("No vectors found in collection")
  348. return GetResult(ids=[[]], documents=[[]], metadatas=[[]])
  349. # Extract the lists from the result
  350. all_ids = all_vectors_result.ids[0] if all_vectors_result.ids else []
  351. all_documents = (
  352. all_vectors_result.documents[0] if all_vectors_result.documents else []
  353. )
  354. all_metadatas = (
  355. all_vectors_result.metadatas[0] if all_vectors_result.metadatas else []
  356. )
  357. # Apply client-side filtering
  358. filtered_ids = []
  359. filtered_documents = []
  360. filtered_metadatas = []
  361. for i, metadata in enumerate(all_metadatas):
  362. if self._matches_filter(metadata, filter):
  363. if i < len(all_ids):
  364. filtered_ids.append(all_ids[i])
  365. if i < len(all_documents):
  366. filtered_documents.append(all_documents[i])
  367. filtered_metadatas.append(metadata)
  368. # Apply limit if specified
  369. if limit and len(filtered_ids) >= limit:
  370. break
  371. log.info(
  372. f"Filter applied: {len(filtered_ids)} vectors match out of {len(all_ids)} total"
  373. )
  374. # Return GetResult format
  375. if filtered_ids:
  376. return GetResult(
  377. ids=[filtered_ids],
  378. documents=[filtered_documents],
  379. metadatas=[filtered_metadatas],
  380. )
  381. else:
  382. return GetResult(ids=[[]], documents=[[]], metadatas=[[]])
  383. except Exception as e:
  384. log.error(f"Error querying collection '{collection_name}': {str(e)}")
  385. # Handle specific AWS exceptions
  386. if hasattr(e, "response") and "Error" in e.response:
  387. error_code = e.response["Error"]["Code"]
  388. if error_code == "NotFoundException":
  389. log.warning(f"Collection '{collection_name}' not found")
  390. return GetResult(ids=[[]], documents=[[]], metadatas=[[]])
  391. elif error_code == "AccessDeniedException":
  392. log.error(
  393. f"Access denied for collection '{collection_name}'. Check permissions."
  394. )
  395. return GetResult(ids=[[]], documents=[[]], metadatas=[[]])
  396. raise
  397. def get(self, collection_name: str) -> Optional[GetResult]:
  398. """
  399. Retrieve all vectors from a collection.
  400. """
  401. if not self.has_collection(collection_name):
  402. log.warning(f"Collection '{collection_name}' does not exist")
  403. return GetResult(ids=[[]], documents=[[]], metadatas=[[]])
  404. try:
  405. log.info(f"Retrieving all vectors from collection '{collection_name}'")
  406. # Initialize result lists
  407. all_ids = []
  408. all_documents = []
  409. all_metadatas = []
  410. # Handle pagination
  411. next_token = None
  412. while True:
  413. # Prepare request parameters
  414. request_params = {
  415. "vectorBucketName": self.bucket_name,
  416. "indexName": collection_name,
  417. "returnData": False, # Don't include vector data (not needed for get)
  418. "returnMetadata": True, # Include metadata
  419. "maxResults": 500, # Use reasonable page size
  420. }
  421. if next_token:
  422. request_params["nextToken"] = next_token
  423. # Call S3 Vector API
  424. response = self.client.list_vectors(**request_params)
  425. # Process vectors in this page
  426. vectors = response.get("vectors", [])
  427. for vector in vectors:
  428. vector_id = vector.get("key")
  429. vector_data = vector.get("data", {})
  430. vector_metadata = vector.get("metadata", {})
  431. # Extract the actual vector array
  432. vector_array = vector_data.get("float32", [])
  433. # For documents, we try to extract text from metadata or use the vector ID
  434. document_text = ""
  435. if isinstance(vector_metadata, dict):
  436. # Get the text field first (highest priority)
  437. document_text = vector_metadata.get("text")
  438. if not document_text:
  439. # Fallback to other possible text fields
  440. document_text = (
  441. vector_metadata.get("content")
  442. or vector_metadata.get("document")
  443. or vector_id
  444. )
  445. # Log the actual content for debugging
  446. log.debug(
  447. f"Document text preview (first 200 chars): {str(document_text)[:200]}"
  448. )
  449. else:
  450. document_text = vector_id
  451. all_ids.append(vector_id)
  452. all_documents.append(document_text)
  453. all_metadatas.append(vector_metadata)
  454. # Check if there are more pages
  455. next_token = response.get("nextToken")
  456. if not next_token:
  457. break
  458. log.info(
  459. f"Retrieved {len(all_ids)} vectors from collection '{collection_name}'"
  460. )
  461. # Return in GetResult format
  462. # The Open WebUI GetResult expects lists of lists, so we wrap each list
  463. if all_ids:
  464. return GetResult(
  465. ids=[all_ids], documents=[all_documents], metadatas=[all_metadatas]
  466. )
  467. else:
  468. return GetResult(ids=[[]], documents=[[]], metadatas=[[]])
  469. except Exception as e:
  470. log.error(
  471. f"Error retrieving vectors from collection '{collection_name}': {str(e)}"
  472. )
  473. # Handle specific AWS exceptions
  474. if hasattr(e, "response") and "Error" in e.response:
  475. error_code = e.response["Error"]["Code"]
  476. if error_code == "NotFoundException":
  477. log.warning(f"Collection '{collection_name}' not found")
  478. return GetResult(ids=[[]], documents=[[]], metadatas=[[]])
  479. elif error_code == "AccessDeniedException":
  480. log.error(
  481. f"Access denied for collection '{collection_name}'. Check permissions."
  482. )
  483. return GetResult(ids=[[]], documents=[[]], metadatas=[[]])
  484. raise
  485. def delete(
  486. self,
  487. collection_name: str,
  488. ids: Optional[List[str]] = None,
  489. filter: Optional[Dict] = None,
  490. ) -> None:
  491. """
  492. Delete vectors by ID or filter from a collection.
  493. """
  494. if not self.has_collection(collection_name):
  495. log.warning(
  496. f"Collection '{collection_name}' does not exist, nothing to delete"
  497. )
  498. return
  499. # Check if this is a knowledge collection (not file-specific)
  500. is_knowledge_collection = not collection_name.startswith("file-")
  501. try:
  502. if ids:
  503. # Delete by specific vector IDs/keys
  504. log.info(
  505. f"Deleting {len(ids)} vectors by IDs from collection '{collection_name}'"
  506. )
  507. self.client.delete_vectors(
  508. vectorBucketName=self.bucket_name,
  509. indexName=collection_name,
  510. keys=ids,
  511. )
  512. log.info(f"Deleted {len(ids)} vectors from index '{collection_name}'")
  513. elif filter:
  514. # Handle filter-based deletion
  515. log.info(
  516. f"Deleting vectors by filter from collection '{collection_name}': {filter}"
  517. )
  518. # If this is a knowledge collection and we have a file_id filter,
  519. # also clean up the corresponding file-specific collection
  520. if is_knowledge_collection and "file_id" in filter:
  521. file_id = filter["file_id"]
  522. file_collection_name = f"file-{file_id}"
  523. if self.has_collection(file_collection_name):
  524. log.info(
  525. f"Found related file-specific collection '{file_collection_name}', deleting it to prevent duplicates"
  526. )
  527. self.delete_collection(file_collection_name)
  528. # For the main collection, implement query-then-delete
  529. # First, query to get IDs matching the filter
  530. query_result = self.query(collection_name, filter)
  531. if query_result and query_result.ids and query_result.ids[0]:
  532. matching_ids = query_result.ids[0]
  533. log.info(
  534. f"Found {len(matching_ids)} vectors matching filter, deleting them"
  535. )
  536. # Delete the matching vectors by ID
  537. self.client.delete_vectors(
  538. vectorBucketName=self.bucket_name,
  539. indexName=collection_name,
  540. keys=matching_ids,
  541. )
  542. log.info(
  543. f"Deleted {len(matching_ids)} vectors from index '{collection_name}' using filter"
  544. )
  545. else:
  546. log.warning("No vectors found matching the filter criteria")
  547. else:
  548. log.warning("No IDs or filter provided for deletion")
  549. except Exception as e:
  550. log.error(
  551. f"Error deleting vectors from collection '{collection_name}': {e}"
  552. )
  553. raise
  554. def reset(self) -> None:
  555. """
  556. Reset/clear all vector data. For S3 Vector, this deletes all indexes.
  557. """
  558. try:
  559. log.warning(
  560. "Reset called - this will delete all vector indexes in the S3 bucket"
  561. )
  562. # List all indexes
  563. response = self.client.list_indexes(vectorBucketName=self.bucket_name)
  564. indexes = response.get("indexes", [])
  565. if not indexes:
  566. log.warning("No indexes found to delete")
  567. return
  568. # Delete all indexes
  569. deleted_count = 0
  570. for index in indexes:
  571. index_name = index.get("indexName")
  572. if index_name:
  573. try:
  574. self.client.delete_index(
  575. vectorBucketName=self.bucket_name, indexName=index_name
  576. )
  577. deleted_count += 1
  578. log.info(f"Deleted index: {index_name}")
  579. except Exception as e:
  580. log.error(f"Error deleting index '{index_name}': {e}")
  581. log.info(f"Reset completed: deleted {deleted_count} indexes")
  582. except Exception as e:
  583. log.error(f"Error during reset: {e}")
  584. raise
  585. def _matches_filter(self, metadata: Dict[str, Any], filter: Dict[str, Any]) -> bool:
  586. """
  587. Check if metadata matches the given filter conditions.
  588. """
  589. if not isinstance(metadata, dict) or not isinstance(filter, dict):
  590. return False
  591. # Check each filter condition
  592. for key, expected_value in filter.items():
  593. # Handle special operators
  594. if key.startswith("$"):
  595. if key == "$and":
  596. # All conditions must match
  597. if not isinstance(expected_value, list):
  598. continue
  599. for condition in expected_value:
  600. if not self._matches_filter(metadata, condition):
  601. return False
  602. elif key == "$or":
  603. # At least one condition must match
  604. if not isinstance(expected_value, list):
  605. continue
  606. any_match = False
  607. for condition in expected_value:
  608. if self._matches_filter(metadata, condition):
  609. any_match = True
  610. break
  611. if not any_match:
  612. return False
  613. continue
  614. # Get the actual value from metadata
  615. actual_value = metadata.get(key)
  616. # Handle different types of expected values
  617. if isinstance(expected_value, dict):
  618. # Handle comparison operators
  619. for op, op_value in expected_value.items():
  620. if op == "$eq":
  621. if actual_value != op_value:
  622. return False
  623. elif op == "$ne":
  624. if actual_value == op_value:
  625. return False
  626. elif op == "$in":
  627. if (
  628. not isinstance(op_value, list)
  629. or actual_value not in op_value
  630. ):
  631. return False
  632. elif op == "$nin":
  633. if isinstance(op_value, list) and actual_value in op_value:
  634. return False
  635. elif op == "$exists":
  636. if bool(op_value) != (key in metadata):
  637. return False
  638. # Add more operators as needed
  639. else:
  640. # Simple equality check
  641. if actual_value != expected_value:
  642. return False
  643. return True