# s3vector.py
  1. from open_webui.retrieval.vector.utils import stringify_metadata
  2. from open_webui.retrieval.vector.main import (
  3. VectorDBBase,
  4. VectorItem,
  5. GetResult,
  6. SearchResult,
  7. )
  8. from open_webui.config import S3_VECTOR_BUCKET_NAME, S3_VECTOR_REGION
  9. from open_webui.env import SRC_LOG_LEVELS
  10. from typing import List, Optional, Dict, Any, Union
  11. import logging
  12. import boto3
  13. log = logging.getLogger(__name__)
  14. log.setLevel(SRC_LOG_LEVELS["RAG"])
  15. class S3VectorClient(VectorDBBase):
  16. """
  17. AWS S3 Vector integration for Open WebUI Knowledge.
  18. """
  19. def __init__(self):
  20. self.bucket_name = S3_VECTOR_BUCKET_NAME
  21. self.region = S3_VECTOR_REGION
  22. # Simple validation - log warnings instead of raising exceptions
  23. if not self.bucket_name:
  24. log.warning("S3_VECTOR_BUCKET_NAME not set - S3Vector will not work")
  25. if not self.region:
  26. log.warning("S3_VECTOR_REGION not set - S3Vector will not work")
  27. if self.bucket_name and self.region:
  28. try:
  29. self.client = boto3.client("s3vectors", region_name=self.region)
  30. log.info(
  31. f"S3Vector client initialized for bucket '{self.bucket_name}' in region '{self.region}'"
  32. )
  33. except Exception as e:
  34. log.error(f"Failed to initialize S3Vector client: {e}")
  35. self.client = None
  36. else:
  37. self.client = None
  38. def _create_index(
  39. self,
  40. index_name: str,
  41. dimension: int,
  42. data_type: str = "float32",
  43. distance_metric: str = "cosine",
  44. ) -> None:
  45. """
  46. Create a new index in the S3 vector bucket for the given collection if it does not exist.
  47. """
  48. if self.has_collection(index_name):
  49. log.debug(f"Index '{index_name}' already exists, skipping creation")
  50. return
  51. try:
  52. self.client.create_index(
  53. vectorBucketName=self.bucket_name,
  54. indexName=index_name,
  55. dataType=data_type,
  56. dimension=dimension,
  57. distanceMetric=distance_metric,
  58. )
  59. log.info(
  60. f"Created S3 index: {index_name} (dim={dimension}, type={data_type}, metric={distance_metric})"
  61. )
  62. except Exception as e:
  63. log.error(f"Error creating S3 index '{index_name}': {e}")
  64. raise
  65. def _filter_metadata(
  66. self, metadata: Dict[str, Any], item_id: str
  67. ) -> Dict[str, Any]:
  68. """
  69. Filter vector metadata keys to comply with S3 Vector API limit of 10 keys maximum.
  70. """
  71. if not isinstance(metadata, dict) or len(metadata) <= 10:
  72. return metadata
  73. # Keep only the first 10 keys, prioritizing important ones based on actual Open WebUI metadata
  74. important_keys = [
  75. "text", # The actual document content
  76. "file_id", # File ID
  77. "source", # Document source file
  78. "title", # Document title
  79. "page", # Page number
  80. "total_pages", # Total pages in document
  81. "embedding_config", # Embedding configuration
  82. "created_by", # User who created it
  83. "name", # Document name
  84. "hash", # Content hash
  85. ]
  86. filtered_metadata = {}
  87. # First, add important keys if they exist
  88. for key in important_keys:
  89. if key in metadata:
  90. filtered_metadata[key] = metadata[key]
  91. if len(filtered_metadata) >= 10:
  92. break
  93. # If we still have room, add other keys
  94. if len(filtered_metadata) < 10:
  95. for key, value in metadata.items():
  96. if key not in filtered_metadata:
  97. filtered_metadata[key] = value
  98. if len(filtered_metadata) >= 10:
  99. break
  100. log.warning(
  101. f"Metadata for key '{item_id}' had {len(metadata)} keys, limited to 10 keys"
  102. )
  103. return filtered_metadata
  104. def has_collection(self, collection_name: str) -> bool:
  105. """
  106. Check if a vector index (collection) exists in the S3 vector bucket.
  107. """
  108. try:
  109. response = self.client.list_indexes(vectorBucketName=self.bucket_name)
  110. indexes = response.get("indexes", [])
  111. return any(idx.get("indexName") == collection_name for idx in indexes)
  112. except Exception as e:
  113. log.error(f"Error listing indexes: {e}")
  114. return False
  115. def delete_collection(self, collection_name: str) -> None:
  116. """
  117. Delete an entire S3 Vector index/collection.
  118. """
  119. if not self.has_collection(collection_name):
  120. log.warning(
  121. f"Collection '{collection_name}' does not exist, nothing to delete"
  122. )
  123. return
  124. try:
  125. log.info(f"Deleting collection '{collection_name}'")
  126. self.client.delete_index(
  127. vectorBucketName=self.bucket_name, indexName=collection_name
  128. )
  129. log.info(f"Successfully deleted collection '{collection_name}'")
  130. except Exception as e:
  131. log.error(f"Error deleting collection '{collection_name}': {e}")
  132. raise
  133. def insert(self, collection_name: str, items: List[VectorItem]) -> None:
  134. """
  135. Insert vector items into the S3 Vector index. Create index if it does not exist.
  136. """
  137. if not items:
  138. log.warning("No items to insert")
  139. return
  140. dimension = len(items[0]["vector"])
  141. try:
  142. if not self.has_collection(collection_name):
  143. log.info(f"Index '{collection_name}' does not exist. Creating index.")
  144. self._create_index(
  145. index_name=collection_name,
  146. dimension=dimension,
  147. data_type="float32",
  148. distance_metric="cosine",
  149. )
  150. # Prepare vectors for insertion
  151. vectors = []
  152. for item in items:
  153. # Ensure vector data is in the correct format for S3 Vector API
  154. vector_data = item["vector"]
  155. if isinstance(vector_data, list):
  156. # Convert list to float32 values as required by S3 Vector API
  157. vector_data = [float(x) for x in vector_data]
  158. # Prepare metadata, ensuring the text field is preserved
  159. metadata = item.get("metadata", {}).copy()
  160. # Add the text field to metadata so it's available for retrieval
  161. metadata["text"] = item["text"]
  162. # Convert metadata to string format for consistency
  163. metadata = stringify_metadata(metadata)
  164. # Filter metadata to comply with S3 Vector API limit of 10 keys
  165. metadata = self._filter_metadata(metadata, item["id"])
  166. vectors.append(
  167. {
  168. "key": item["id"],
  169. "data": {"float32": vector_data},
  170. "metadata": metadata,
  171. }
  172. )
  173. # Insert vectors
  174. self.client.put_vectors(
  175. vectorBucketName=self.bucket_name,
  176. indexName=collection_name,
  177. vectors=vectors,
  178. )
  179. log.info(f"Inserted {len(vectors)} vectors into index '{collection_name}'.")
  180. except Exception as e:
  181. log.error(f"Error inserting vectors: {e}")
  182. raise
  183. def upsert(self, collection_name: str, items: List[VectorItem]) -> None:
  184. """
  185. Insert or update vector items in the S3 Vector index. Create index if it does not exist.
  186. """
  187. if not items:
  188. log.warning("No items to upsert")
  189. return
  190. dimension = len(items[0]["vector"])
  191. log.info(f"Upsert dimension: {dimension}")
  192. try:
  193. if not self.has_collection(collection_name):
  194. log.info(
  195. f"Index '{collection_name}' does not exist. Creating index for upsert."
  196. )
  197. self._create_index(
  198. index_name=collection_name,
  199. dimension=dimension,
  200. data_type="float32",
  201. distance_metric="cosine",
  202. )
  203. # Prepare vectors for upsert
  204. vectors = []
  205. for item in items:
  206. # Ensure vector data is in the correct format for S3 Vector API
  207. vector_data = item["vector"]
  208. if isinstance(vector_data, list):
  209. # Convert list to float32 values as required by S3 Vector API
  210. vector_data = [float(x) for x in vector_data]
  211. # Prepare metadata, ensuring the text field is preserved
  212. metadata = item.get("metadata", {}).copy()
  213. # Add the text field to metadata so it's available for retrieval
  214. metadata["text"] = item["text"]
  215. # Convert metadata to string format for consistency
  216. metadata = stringify_metadata(metadata)
  217. # Filter metadata to comply with S3 Vector API limit of 10 keys
  218. metadata = self._filter_metadata(metadata, item["id"])
  219. vectors.append(
  220. {
  221. "key": item["id"],
  222. "data": {"float32": vector_data},
  223. "metadata": metadata,
  224. }
  225. )
  226. # Upsert vectors (using put_vectors for upsert semantics)
  227. log.info(
  228. f"Upserting {len(vectors)} vectors. First vector sample: key={vectors[0]['key']}, data_type={type(vectors[0]['data']['float32'])}, data_len={len(vectors[0]['data']['float32'])}"
  229. )
  230. self.client.put_vectors(
  231. vectorBucketName=self.bucket_name,
  232. indexName=collection_name,
  233. vectors=vectors,
  234. )
  235. log.info(f"Upserted {len(vectors)} vectors into index '{collection_name}'.")
  236. except Exception as e:
  237. log.error(f"Error upserting vectors: {e}")
  238. raise
  239. def search(
  240. self, collection_name: str, vectors: List[List[Union[float, int]]], limit: int
  241. ) -> Optional[SearchResult]:
  242. """
  243. Search for similar vectors in a collection using multiple query vectors.
  244. """
  245. if not self.has_collection(collection_name):
  246. log.warning(f"Collection '{collection_name}' does not exist")
  247. return None
  248. if not vectors:
  249. log.warning("No query vectors provided")
  250. return None
  251. try:
  252. log.info(
  253. f"Searching collection '{collection_name}' with {len(vectors)} query vectors, limit={limit}"
  254. )
  255. # Initialize result lists
  256. all_ids = []
  257. all_documents = []
  258. all_metadatas = []
  259. all_distances = []
  260. # Process each query vector
  261. for i, query_vector in enumerate(vectors):
  262. log.debug(f"Processing query vector {i+1}/{len(vectors)}")
  263. # Prepare the query vector in S3 Vector format
  264. query_vector_dict = {"float32": [float(x) for x in query_vector]}
  265. # Call S3 Vector query API
  266. response = self.client.query_vectors(
  267. vectorBucketName=self.bucket_name,
  268. indexName=collection_name,
  269. topK=limit,
  270. queryVector=query_vector_dict,
  271. returnMetadata=True,
  272. returnDistance=True,
  273. )
  274. # Process results for this query
  275. query_ids = []
  276. query_documents = []
  277. query_metadatas = []
  278. query_distances = []
  279. result_vectors = response.get("vectors", [])
  280. for vector in result_vectors:
  281. vector_id = vector.get("key")
  282. vector_metadata = vector.get("metadata", {})
  283. vector_distance = vector.get("distance", 0.0)
  284. # Extract document text from metadata
  285. document_text = ""
  286. if isinstance(vector_metadata, dict):
  287. # Get the text field first (highest priority)
  288. document_text = vector_metadata.get("text")
  289. if not document_text:
  290. # Fallback to other possible text fields
  291. document_text = (
  292. vector_metadata.get("content")
  293. or vector_metadata.get("document")
  294. or vector_id
  295. )
  296. else:
  297. document_text = vector_id
  298. query_ids.append(vector_id)
  299. query_documents.append(document_text)
  300. query_metadatas.append(vector_metadata)
  301. query_distances.append(vector_distance)
  302. # Add this query's results to the overall results
  303. all_ids.append(query_ids)
  304. all_documents.append(query_documents)
  305. all_metadatas.append(query_metadatas)
  306. all_distances.append(query_distances)
  307. log.info(f"Search completed. Found results for {len(all_ids)} queries")
  308. # Return SearchResult format
  309. return SearchResult(
  310. ids=all_ids if all_ids else None,
  311. documents=all_documents if all_documents else None,
  312. metadatas=all_metadatas if all_metadatas else None,
  313. distances=all_distances if all_distances else None,
  314. )
  315. except Exception as e:
  316. log.error(f"Error searching collection '{collection_name}': {str(e)}")
  317. # Handle specific AWS exceptions
  318. if hasattr(e, "response") and "Error" in e.response:
  319. error_code = e.response["Error"]["Code"]
  320. if error_code == "NotFoundException":
  321. log.warning(f"Collection '{collection_name}' not found")
  322. return None
  323. elif error_code == "ValidationException":
  324. log.error(f"Invalid query vector dimensions or parameters")
  325. return None
  326. elif error_code == "AccessDeniedException":
  327. log.error(
  328. f"Access denied for collection '{collection_name}'. Check permissions."
  329. )
  330. return None
  331. raise
  332. def query(
  333. self, collection_name: str, filter: Dict, limit: Optional[int] = None
  334. ) -> Optional[GetResult]:
  335. """
  336. Query vectors from a collection using metadata filter.
  337. """
  338. if not self.has_collection(collection_name):
  339. log.warning(f"Collection '{collection_name}' does not exist")
  340. return GetResult(ids=[[]], documents=[[]], metadatas=[[]])
  341. if not filter:
  342. log.warning("No filter provided, returning all vectors")
  343. return self.get(collection_name)
  344. try:
  345. log.info(f"Querying collection '{collection_name}' with filter: {filter}")
  346. # For S3 Vector, we need to use list_vectors and then filter results
  347. # Since S3 Vector may not support complex server-side filtering,
  348. # we'll retrieve all vectors and filter client-side
  349. # Get all vectors first
  350. all_vectors_result = self.get(collection_name)
  351. if not all_vectors_result or not all_vectors_result.ids:
  352. log.warning("No vectors found in collection")
  353. return GetResult(ids=[[]], documents=[[]], metadatas=[[]])
  354. # Extract the lists from the result
  355. all_ids = all_vectors_result.ids[0] if all_vectors_result.ids else []
  356. all_documents = (
  357. all_vectors_result.documents[0] if all_vectors_result.documents else []
  358. )
  359. all_metadatas = (
  360. all_vectors_result.metadatas[0] if all_vectors_result.metadatas else []
  361. )
  362. # Apply client-side filtering
  363. filtered_ids = []
  364. filtered_documents = []
  365. filtered_metadatas = []
  366. for i, metadata in enumerate(all_metadatas):
  367. if self._matches_filter(metadata, filter):
  368. if i < len(all_ids):
  369. filtered_ids.append(all_ids[i])
  370. if i < len(all_documents):
  371. filtered_documents.append(all_documents[i])
  372. filtered_metadatas.append(metadata)
  373. # Apply limit if specified
  374. if limit and len(filtered_ids) >= limit:
  375. break
  376. log.info(
  377. f"Filter applied: {len(filtered_ids)} vectors match out of {len(all_ids)} total"
  378. )
  379. # Return GetResult format
  380. if filtered_ids:
  381. return GetResult(
  382. ids=[filtered_ids],
  383. documents=[filtered_documents],
  384. metadatas=[filtered_metadatas],
  385. )
  386. else:
  387. return GetResult(ids=[[]], documents=[[]], metadatas=[[]])
  388. except Exception as e:
  389. log.error(f"Error querying collection '{collection_name}': {str(e)}")
  390. # Handle specific AWS exceptions
  391. if hasattr(e, "response") and "Error" in e.response:
  392. error_code = e.response["Error"]["Code"]
  393. if error_code == "NotFoundException":
  394. log.warning(f"Collection '{collection_name}' not found")
  395. return GetResult(ids=[[]], documents=[[]], metadatas=[[]])
  396. elif error_code == "AccessDeniedException":
  397. log.error(
  398. f"Access denied for collection '{collection_name}'. Check permissions."
  399. )
  400. return GetResult(ids=[[]], documents=[[]], metadatas=[[]])
  401. raise
  402. def get(self, collection_name: str) -> Optional[GetResult]:
  403. """
  404. Retrieve all vectors from a collection.
  405. """
  406. if not self.has_collection(collection_name):
  407. log.warning(f"Collection '{collection_name}' does not exist")
  408. return GetResult(ids=[[]], documents=[[]], metadatas=[[]])
  409. try:
  410. log.info(f"Retrieving all vectors from collection '{collection_name}'")
  411. # Initialize result lists
  412. all_ids = []
  413. all_documents = []
  414. all_metadatas = []
  415. # Handle pagination
  416. next_token = None
  417. while True:
  418. # Prepare request parameters
  419. request_params = {
  420. "vectorBucketName": self.bucket_name,
  421. "indexName": collection_name,
  422. "returnData": False, # Don't include vector data (not needed for get)
  423. "returnMetadata": True, # Include metadata
  424. "maxResults": 500, # Use reasonable page size
  425. }
  426. if next_token:
  427. request_params["nextToken"] = next_token
  428. # Call S3 Vector API
  429. response = self.client.list_vectors(**request_params)
  430. # Process vectors in this page
  431. vectors = response.get("vectors", [])
  432. for vector in vectors:
  433. vector_id = vector.get("key")
  434. vector_data = vector.get("data", {})
  435. vector_metadata = vector.get("metadata", {})
  436. # Extract the actual vector array
  437. vector_array = vector_data.get("float32", [])
  438. # For documents, we try to extract text from metadata or use the vector ID
  439. document_text = ""
  440. if isinstance(vector_metadata, dict):
  441. # Get the text field first (highest priority)
  442. document_text = vector_metadata.get("text")
  443. if not document_text:
  444. # Fallback to other possible text fields
  445. document_text = (
  446. vector_metadata.get("content")
  447. or vector_metadata.get("document")
  448. or vector_id
  449. )
  450. # Log the actual content for debugging
  451. log.debug(
  452. f"Document text preview (first 200 chars): {str(document_text)[:200]}"
  453. )
  454. else:
  455. document_text = vector_id
  456. all_ids.append(vector_id)
  457. all_documents.append(document_text)
  458. all_metadatas.append(vector_metadata)
  459. # Check if there are more pages
  460. next_token = response.get("nextToken")
  461. if not next_token:
  462. break
  463. log.info(
  464. f"Retrieved {len(all_ids)} vectors from collection '{collection_name}'"
  465. )
  466. # Return in GetResult format
  467. # The Open WebUI GetResult expects lists of lists, so we wrap each list
  468. if all_ids:
  469. return GetResult(
  470. ids=[all_ids], documents=[all_documents], metadatas=[all_metadatas]
  471. )
  472. else:
  473. return GetResult(ids=[[]], documents=[[]], metadatas=[[]])
  474. except Exception as e:
  475. log.error(
  476. f"Error retrieving vectors from collection '{collection_name}': {str(e)}"
  477. )
  478. # Handle specific AWS exceptions
  479. if hasattr(e, "response") and "Error" in e.response:
  480. error_code = e.response["Error"]["Code"]
  481. if error_code == "NotFoundException":
  482. log.warning(f"Collection '{collection_name}' not found")
  483. return GetResult(ids=[[]], documents=[[]], metadatas=[[]])
  484. elif error_code == "AccessDeniedException":
  485. log.error(
  486. f"Access denied for collection '{collection_name}'. Check permissions."
  487. )
  488. return GetResult(ids=[[]], documents=[[]], metadatas=[[]])
  489. raise
  490. def delete(
  491. self,
  492. collection_name: str,
  493. ids: Optional[List[str]] = None,
  494. filter: Optional[Dict] = None,
  495. ) -> None:
  496. """
  497. Delete vectors by ID or filter from a collection.
  498. """
  499. if not self.has_collection(collection_name):
  500. log.warning(
  501. f"Collection '{collection_name}' does not exist, nothing to delete"
  502. )
  503. return
  504. # Check if this is a knowledge collection (not file-specific)
  505. is_knowledge_collection = not collection_name.startswith("file-")
  506. try:
  507. if ids:
  508. # Delete by specific vector IDs/keys
  509. log.info(
  510. f"Deleting {len(ids)} vectors by IDs from collection '{collection_name}'"
  511. )
  512. self.client.delete_vectors(
  513. vectorBucketName=self.bucket_name,
  514. indexName=collection_name,
  515. keys=ids,
  516. )
  517. log.info(f"Deleted {len(ids)} vectors from index '{collection_name}'")
  518. elif filter:
  519. # Handle filter-based deletion
  520. log.info(
  521. f"Deleting vectors by filter from collection '{collection_name}': {filter}"
  522. )
  523. # If this is a knowledge collection and we have a file_id filter,
  524. # also clean up the corresponding file-specific collection
  525. if is_knowledge_collection and "file_id" in filter:
  526. file_id = filter["file_id"]
  527. file_collection_name = f"file-{file_id}"
  528. if self.has_collection(file_collection_name):
  529. log.info(
  530. f"Found related file-specific collection '{file_collection_name}', deleting it to prevent duplicates"
  531. )
  532. self.delete_collection(file_collection_name)
  533. # For the main collection, implement query-then-delete
  534. # First, query to get IDs matching the filter
  535. query_result = self.query(collection_name, filter)
  536. if query_result and query_result.ids and query_result.ids[0]:
  537. matching_ids = query_result.ids[0]
  538. log.info(
  539. f"Found {len(matching_ids)} vectors matching filter, deleting them"
  540. )
  541. # Delete the matching vectors by ID
  542. self.client.delete_vectors(
  543. vectorBucketName=self.bucket_name,
  544. indexName=collection_name,
  545. keys=matching_ids,
  546. )
  547. log.info(
  548. f"Deleted {len(matching_ids)} vectors from index '{collection_name}' using filter"
  549. )
  550. else:
  551. log.warning("No vectors found matching the filter criteria")
  552. else:
  553. log.warning("No IDs or filter provided for deletion")
  554. except Exception as e:
  555. log.error(
  556. f"Error deleting vectors from collection '{collection_name}': {e}"
  557. )
  558. raise
  559. def reset(self) -> None:
  560. """
  561. Reset/clear all vector data. For S3 Vector, this deletes all indexes.
  562. """
  563. try:
  564. log.warning(
  565. "Reset called - this will delete all vector indexes in the S3 bucket"
  566. )
  567. # List all indexes
  568. response = self.client.list_indexes(vectorBucketName=self.bucket_name)
  569. indexes = response.get("indexes", [])
  570. if not indexes:
  571. log.warning("No indexes found to delete")
  572. return
  573. # Delete all indexes
  574. deleted_count = 0
  575. for index in indexes:
  576. index_name = index.get("indexName")
  577. if index_name:
  578. try:
  579. self.client.delete_index(
  580. vectorBucketName=self.bucket_name, indexName=index_name
  581. )
  582. deleted_count += 1
  583. log.info(f"Deleted index: {index_name}")
  584. except Exception as e:
  585. log.error(f"Error deleting index '{index_name}': {e}")
  586. log.info(f"Reset completed: deleted {deleted_count} indexes")
  587. except Exception as e:
  588. log.error(f"Error during reset: {e}")
  589. raise
  590. def _matches_filter(self, metadata: Dict[str, Any], filter: Dict[str, Any]) -> bool:
  591. """
  592. Check if metadata matches the given filter conditions.
  593. """
  594. if not isinstance(metadata, dict) or not isinstance(filter, dict):
  595. return False
  596. # Check each filter condition
  597. for key, expected_value in filter.items():
  598. # Handle special operators
  599. if key.startswith("$"):
  600. if key == "$and":
  601. # All conditions must match
  602. if not isinstance(expected_value, list):
  603. continue
  604. for condition in expected_value:
  605. if not self._matches_filter(metadata, condition):
  606. return False
  607. elif key == "$or":
  608. # At least one condition must match
  609. if not isinstance(expected_value, list):
  610. continue
  611. any_match = False
  612. for condition in expected_value:
  613. if self._matches_filter(metadata, condition):
  614. any_match = True
  615. break
  616. if not any_match:
  617. return False
  618. continue
  619. # Get the actual value from metadata
  620. actual_value = metadata.get(key)
  621. # Handle different types of expected values
  622. if isinstance(expected_value, dict):
  623. # Handle comparison operators
  624. for op, op_value in expected_value.items():
  625. if op == "$eq":
  626. if actual_value != op_value:
  627. return False
  628. elif op == "$ne":
  629. if actual_value == op_value:
  630. return False
  631. elif op == "$in":
  632. if (
  633. not isinstance(op_value, list)
  634. or actual_value not in op_value
  635. ):
  636. return False
  637. elif op == "$nin":
  638. if isinstance(op_value, list) and actual_value in op_value:
  639. return False
  640. elif op == "$exists":
  641. if bool(op_value) != (key in metadata):
  642. return False
  643. # Add more operators as needed
  644. else:
  645. # Simple equality check
  646. if actual_value != expected_value:
  647. return False
  648. return True