Browse Source

chore: run formatting

0xThresh.eth 2 months ago
parent
commit
860f3b3cab

+ 285 - 205
backend/open_webui/retrieval/vector/dbs/s3vector.py

@@ -1,4 +1,9 @@
-from open_webui.retrieval.vector.main import VectorDBBase, VectorItem, GetResult, SearchResult
+from open_webui.retrieval.vector.main import (
+    VectorDBBase,
+    VectorItem,
+    GetResult,
+    SearchResult,
+)
 from open_webui.config import S3_VECTOR_BUCKET_NAME, S3_VECTOR_REGION
 from open_webui.env import SRC_LOG_LEVELS
 from typing import List, Optional, Dict, Any, Union
@@ -8,39 +13,48 @@ import boto3
 log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["RAG"])
 
+
 class S3VectorClient(VectorDBBase):
     """
     AWS S3 Vector integration for Open WebUI Knowledge.
     """
-    
+
     def __init__(self):
         self.bucket_name = S3_VECTOR_BUCKET_NAME
         self.region = S3_VECTOR_REGION
-        
+
         # Simple validation - log warnings instead of raising exceptions
         if not self.bucket_name:
             log.warning("S3_VECTOR_BUCKET_NAME not set - S3Vector will not work")
         if not self.region:
             log.warning("S3_VECTOR_REGION not set - S3Vector will not work")
-            
+
         if self.bucket_name and self.region:
             try:
                 self.client = boto3.client("s3vectors", region_name=self.region)
-                log.info(f"S3Vector client initialized for bucket '{self.bucket_name}' in region '{self.region}'")
+                log.info(
+                    f"S3Vector client initialized for bucket '{self.bucket_name}' in region '{self.region}'"
+                )
             except Exception as e:
                 log.error(f"Failed to initialize S3Vector client: {e}")
                 self.client = None
         else:
             self.client = None
 
-    def _create_index(self, index_name: str, dimension: int, data_type: str = "float32", distance_metric: str = "cosine") -> None:
+    def _create_index(
+        self,
+        index_name: str,
+        dimension: int,
+        data_type: str = "float32",
+        distance_metric: str = "cosine",
+    ) -> None:
         """
         """
         Create a new index in the S3 vector bucket for the given collection if it does not exist.
         Create a new index in the S3 vector bucket for the given collection if it does not exist.
         """
         """
         if self.has_collection(index_name):
         if self.has_collection(index_name):
             log.debug(f"Index '{index_name}' already exists, skipping creation")
             log.debug(f"Index '{index_name}' already exists, skipping creation")
             return
             return
-            
+
         try:
         try:
             self.client.create_index(
             self.client.create_index(
                 vectorBucketName=self.bucket_name,
                 vectorBucketName=self.bucket_name,
@@ -49,40 +63,44 @@ class S3VectorClient(VectorDBBase):
                 dimension=dimension,
                 distanceMetric=distance_metric,
             )
-            log.info(f"Created S3 index: {index_name} (dim={dimension}, type={data_type}, metric={distance_metric})")
+            log.info(
+                f"Created S3 index: {index_name} (dim={dimension}, type={data_type}, metric={distance_metric})"
+            )
         except Exception as e:
             log.error(f"Error creating S3 index '{index_name}': {e}")
             raise
 
-    def _filter_metadata(self, metadata: Dict[str, Any], item_id: str) -> Dict[str, Any]:
+    def _filter_metadata(
+        self, metadata: Dict[str, Any], item_id: str
+    ) -> Dict[str, Any]:
         """
         """
         Filter vector metadata keys to comply with S3 Vector API limit of 10 keys maximum.
         Filter vector metadata keys to comply with S3 Vector API limit of 10 keys maximum.
         """
         """
         if not isinstance(metadata, dict) or len(metadata) <= 10:
         if not isinstance(metadata, dict) or len(metadata) <= 10:
             return metadata
             return metadata
-            
-        # Keep only the first 10 keys, prioritizing important ones based on actual Open WebUI metadata            
+
+        # Keep only the first 10 keys, prioritizing important ones based on actual Open WebUI metadata
         important_keys = [
-            'text',             # The actual document content
-            'file_id',          # File ID
-            'source',           # Document source file
-            'title',            # Document title
-            'page',             # Page number
-            'total_pages',      # Total pages in document
-            'embedding_config', # Embedding configuration
-            'created_by',       # User who created it
-            'name',             # Document name
-            'hash',             # Content hash
+            "text",  # The actual document content
+            "file_id",  # File ID
+            "source",  # Document source file
+            "title",  # Document title
+            "page",  # Page number
+            "total_pages",  # Total pages in document
+            "embedding_config",  # Embedding configuration
+            "created_by",  # User who created it
+            "name",  # Document name
+            "hash",  # Content hash
         ]
         filtered_metadata = {}
-        
+
         # First, add important keys if they exist
         for key in important_keys:
             if key in metadata:
                 filtered_metadata[key] = metadata[key]
             if len(filtered_metadata) >= 10:
                 break
-                
+
         # If we still have room, add other keys
         if len(filtered_metadata) < 10:
             for key, value in metadata.items():
@@ -90,15 +108,17 @@ class S3VectorClient(VectorDBBase):
                     filtered_metadata[key] = value
                     if len(filtered_metadata) >= 10:
                         break
-                        
-        log.warning(f"Metadata for key '{item_id}' had {len(metadata)} keys, limited to 10 keys")
+
+        log.warning(
+            f"Metadata for key '{item_id}' had {len(metadata)} keys, limited to 10 keys"
+        )
         return filtered_metadata
 
     def has_collection(self, collection_name: str) -> bool:
         """
         Check if a vector index (collection) exists in the S3 vector bucket.
         """
-            
+
         try:
             response = self.client.list_indexes(vectorBucketName=self.bucket_name)
             indexes = response.get("indexes", [])
@@ -106,21 +126,22 @@ class S3VectorClient(VectorDBBase):
         except Exception as e:
             log.error(f"Error listing indexes: {e}")
             return False
-            
+
     def delete_collection(self, collection_name: str) -> None:
         """
         Delete an entire S3 Vector index/collection.
         """
-            
+
         if not self.has_collection(collection_name):
-            log.warning(f"Collection '{collection_name}' does not exist, nothing to delete")
+            log.warning(
+                f"Collection '{collection_name}' does not exist, nothing to delete"
+            )
             return
-            
+
         try:
             log.info(f"Deleting collection '{collection_name}'")
             self.client.delete_index(
-                vectorBucketName=self.bucket_name,
-                indexName=collection_name
+                vectorBucketName=self.bucket_name, indexName=collection_name
             )
             log.info(f"Successfully deleted collection '{collection_name}'")
         except Exception as e:
@@ -134,9 +155,9 @@ class S3VectorClient(VectorDBBase):
         if not items:
             log.warning("No items to insert")
             return
-            
+
         dimension = len(items[0]["vector"])
-        
+
         try:
             if not self.has_collection(collection_name):
                 log.info(f"Index '{collection_name}' does not exist. Creating index.")
@@ -146,7 +167,7 @@ class S3VectorClient(VectorDBBase):
                     data_type="float32",
                     data_type="float32",
                     distance_metric="cosine",
                     distance_metric="cosine",
                 )
                 )
-            
+
             # Prepare vectors for insertion
             # Prepare vectors for insertion
             vectors = []
             vectors = []
             for item in items:
             for item in items:
@@ -155,28 +176,28 @@ class S3VectorClient(VectorDBBase):
                 if isinstance(vector_data, list):
                     # Convert list to float32 values as required by S3 Vector API
                     vector_data = [float(x) for x in vector_data]
-                
+
                 # Prepare metadata, ensuring the text field is preserved
                 metadata = item.get("metadata", {}).copy()
-                
+
                 # Add the text field to metadata so it's available for retrieval
                 metadata["text"] = item["text"]
-                
+
                 # Filter metadata to comply with S3 Vector API limit of 10 keys
                 metadata = self._filter_metadata(metadata, item["id"])
-                
-                vectors.append({
-                    "key": item["id"],
-                    "data": {
-                        "float32": vector_data
-                    },
-                    "metadata": metadata
-                })
+
+                vectors.append(
+                    {
+                        "key": item["id"],
+                        "data": {"float32": vector_data},
+                        "metadata": metadata,
+                    }
+                )
             # Insert vectors
             self.client.put_vectors(
                 vectorBucketName=self.bucket_name,
                 indexName=collection_name,
-                vectors=vectors
+                vectors=vectors,
             )
             log.info(f"Inserted {len(vectors)} vectors into index '{collection_name}'.")
         except Exception as e:
@@ -190,20 +211,22 @@ class S3VectorClient(VectorDBBase):
         if not items:
             log.warning("No items to upsert")
             return
-            
+
         dimension = len(items[0]["vector"])
         log.info(f"Upsert dimension: {dimension}")
-        
+
         try:
             if not self.has_collection(collection_name):
-                log.info(f"Index '{collection_name}' does not exist. Creating index for upsert.")
+                log.info(
+                    f"Index '{collection_name}' does not exist. Creating index for upsert."
+                )
                 self._create_index(
                     index_name=collection_name,
                     dimension=dimension,
                     data_type="float32",
                     distance_metric="cosine",
                 )
-            
+
             # Prepare vectors for upsert
             vectors = []
             for item in items:
@@ -212,65 +235,69 @@ class S3VectorClient(VectorDBBase):
                 if isinstance(vector_data, list):
                     # Convert list to float32 values as required by S3 Vector API
                     vector_data = [float(x) for x in vector_data]
-                
+
                 # Prepare metadata, ensuring the text field is preserved
                 metadata = item.get("metadata", {}).copy()
                 # Add the text field to metadata so it's available for retrieval
                 metadata["text"] = item["text"]
-                
+
                 # Filter metadata to comply with S3 Vector API limit of 10 keys
                 metadata = self._filter_metadata(metadata, item["id"])
-                
-                vectors.append({
-                    "key": item["id"],
-                    "data": {
-                        "float32": vector_data
-                    },
-                    "metadata": metadata
-                })
+
+                vectors.append(
+                    {
+                        "key": item["id"],
+                        "data": {"float32": vector_data},
+                        "metadata": metadata,
+                    }
+                )
             # Upsert vectors (using put_vectors for upsert semantics)
-            log.info(f"Upserting {len(vectors)} vectors. First vector sample: key={vectors[0]['key']}, data_type={type(vectors[0]['data']['float32'])}, data_len={len(vectors[0]['data']['float32'])}")
+            log.info(
+                f"Upserting {len(vectors)} vectors. First vector sample: key={vectors[0]['key']}, data_type={type(vectors[0]['data']['float32'])}, data_len={len(vectors[0]['data']['float32'])}"
+            )
             self.client.put_vectors(
                 vectorBucketName=self.bucket_name,
                 indexName=collection_name,
-                vectors=vectors
+                vectors=vectors,
             )
             log.info(f"Upserted {len(vectors)} vectors into index '{collection_name}'.")
         except Exception as e:
             log.error(f"Error upserting vectors: {e}")
             raise
 
-    def search(self, collection_name: str, vectors: List[List[Union[float, int]]], limit: int) -> Optional[SearchResult]:
+    def search(
+        self, collection_name: str, vectors: List[List[Union[float, int]]], limit: int
+    ) -> Optional[SearchResult]:
         """
         """
         Search for similar vectors in a collection using multiple query vectors.
         Search for similar vectors in a collection using multiple query vectors.
         """
         """
-            
+
         if not self.has_collection(collection_name):
         if not self.has_collection(collection_name):
             log.warning(f"Collection '{collection_name}' does not exist")
             log.warning(f"Collection '{collection_name}' does not exist")
             return None
             return None
-            
+
         if not vectors:
         if not vectors:
             log.warning("No query vectors provided")
             log.warning("No query vectors provided")
             return None
             return None
-            
+
         try:
         try:
-            log.info(f"Searching collection '{collection_name}' with {len(vectors)} query vectors, limit={limit}")
-            
+            log.info(
+                f"Searching collection '{collection_name}' with {len(vectors)} query vectors, limit={limit}"
+            )
+
             # Initialize result lists
             all_ids = []
             all_documents = []
             all_metadatas = []
             all_distances = []
-            
+
             # Process each query vector
             for i, query_vector in enumerate(vectors):
                 log.debug(f"Processing query vector {i+1}/{len(vectors)}")
-                
+
                 # Prepare the query vector in S3 Vector format
-                query_vector_dict = {
-                    'float32': [float(x) for x in query_vector]
-                }
-                
+                query_vector_dict = {"float32": [float(x) for x in query_vector]}
+
                 # Call S3 Vector query API
                 response = self.client.query_vectors(
                     vectorBucketName=self.bucket_name,
@@ -278,109 +305,119 @@ class S3VectorClient(VectorDBBase):
                     topK=limit,
                     queryVector=query_vector_dict,
                     returnMetadata=True,
-                    returnDistance=True
+                    returnDistance=True,
                 )
-                
+
                 # Process results for this query
                 query_ids = []
                 query_documents = []
                 query_metadatas = []
                 query_distances = []
-                
-                result_vectors = response.get('vectors', [])
-                
+
+                result_vectors = response.get("vectors", [])
+
                 for vector in result_vectors:
-                    vector_id = vector.get('key')
-                    vector_metadata = vector.get('metadata', {})
-                    vector_distance = vector.get('distance', 0.0)
-                    
+                    vector_id = vector.get("key")
+                    vector_metadata = vector.get("metadata", {})
+                    vector_distance = vector.get("distance", 0.0)
+
                     # Extract document text from metadata
                     document_text = ""
                     if isinstance(vector_metadata, dict):
                         # Get the text field first (highest priority)
-                        document_text = vector_metadata.get('text')
+                        document_text = vector_metadata.get("text")
                         if not document_text:
                             # Fallback to other possible text fields
-                            document_text = (vector_metadata.get('content') or 
-                                           vector_metadata.get('document') or 
-                                           vector_id)
+                            document_text = (
+                                vector_metadata.get("content")
+                                or vector_metadata.get("document")
+                                or vector_id
+                            )
                     else:
                         document_text = vector_id
-                    
+
                     query_ids.append(vector_id)
                     query_documents.append(document_text)
                     query_metadatas.append(vector_metadata)
                     query_distances.append(vector_distance)
-                
+
                 # Add this query's results to the overall results
                 all_ids.append(query_ids)
                 all_documents.append(query_documents)
                 all_metadatas.append(query_metadatas)
                 all_distances.append(query_distances)
-            
+
             log.info(f"Search completed. Found results for {len(all_ids)} queries")
-            
+
             # Return SearchResult format
             return SearchResult(
                 ids=all_ids if all_ids else None,
                 documents=all_documents if all_documents else None,
                 metadatas=all_metadatas if all_metadatas else None,
-                distances=all_distances if all_distances else None
+                distances=all_distances if all_distances else None,
             )
-            
+
         except Exception as e:
             log.error(f"Error searching collection '{collection_name}': {str(e)}")
             # Handle specific AWS exceptions
-            if hasattr(e, 'response') and 'Error' in e.response:
-                error_code = e.response['Error']['Code']
-                if error_code == 'NotFoundException':
+            if hasattr(e, "response") and "Error" in e.response:
+                error_code = e.response["Error"]["Code"]
+                if error_code == "NotFoundException":
                     log.warning(f"Collection '{collection_name}' not found")
                     log.warning(f"Collection '{collection_name}' not found")
                     return None
                     return None
-                elif error_code == 'ValidationException':
+                elif error_code == "ValidationException":
                     log.error(f"Invalid query vector dimensions or parameters")
                     log.error(f"Invalid query vector dimensions or parameters")
                     return None
                     return None
-                elif error_code == 'AccessDeniedException':
-                    log.error(f"Access denied for collection '{collection_name}'. Check permissions.")
+                elif error_code == "AccessDeniedException":
+                    log.error(
+                        f"Access denied for collection '{collection_name}'. Check permissions."
+                    )
                     return None
             raise
 
-    def query(self, collection_name: str, filter: Dict, limit: Optional[int] = None) -> Optional[GetResult]:
+    def query(
+        self, collection_name: str, filter: Dict, limit: Optional[int] = None
+    ) -> Optional[GetResult]:
         """
         """
         Query vectors from a collection using metadata filter.
         Query vectors from a collection using metadata filter.
         """
         """
-            
+
         if not self.has_collection(collection_name):
         if not self.has_collection(collection_name):
             log.warning(f"Collection '{collection_name}' does not exist")
             log.warning(f"Collection '{collection_name}' does not exist")
             return GetResult(ids=[[]], documents=[[]], metadatas=[[]])
             return GetResult(ids=[[]], documents=[[]], metadatas=[[]])
-            
+
         if not filter:
         if not filter:
             log.warning("No filter provided, returning all vectors")
             log.warning("No filter provided, returning all vectors")
             return self.get(collection_name)
             return self.get(collection_name)
-            
+
         try:
         try:
             log.info(f"Querying collection '{collection_name}' with filter: {filter}")
             log.info(f"Querying collection '{collection_name}' with filter: {filter}")
-            
+
             # For S3 Vector, we need to use list_vectors and then filter results
             # For S3 Vector, we need to use list_vectors and then filter results
             # Since S3 Vector may not support complex server-side filtering,
             # Since S3 Vector may not support complex server-side filtering,
             # we'll retrieve all vectors and filter client-side
             # we'll retrieve all vectors and filter client-side
-            
+
             # Get all vectors first
             # Get all vectors first
             all_vectors_result = self.get(collection_name)
             all_vectors_result = self.get(collection_name)
-            
+
             if not all_vectors_result or not all_vectors_result.ids:
             if not all_vectors_result or not all_vectors_result.ids:
                 log.warning("No vectors found in collection")
                 log.warning("No vectors found in collection")
                 return GetResult(ids=[[]], documents=[[]], metadatas=[[]])
                 return GetResult(ids=[[]], documents=[[]], metadatas=[[]])
-                
+
             # Extract the lists from the result
             # Extract the lists from the result
             all_ids = all_vectors_result.ids[0] if all_vectors_result.ids else []
             all_ids = all_vectors_result.ids[0] if all_vectors_result.ids else []
-            all_documents = all_vectors_result.documents[0] if all_vectors_result.documents else []
-            all_metadatas = all_vectors_result.metadatas[0] if all_vectors_result.metadatas else []
-            
+            all_documents = (
+                all_vectors_result.documents[0] if all_vectors_result.documents else []
+            )
+            all_metadatas = (
+                all_vectors_result.metadatas[0] if all_vectors_result.metadatas else []
+            )
+
             # Apply client-side filtering
             filtered_ids = []
             filtered_documents = []
             filtered_metadatas = []
-            
+
             for i, metadata in enumerate(all_metadatas):
                 if self._matches_filter(metadata, filter):
                     if i < len(all_ids):
@@ -388,29 +425,37 @@ class S3VectorClient(VectorDBBase):
                     if i < len(all_documents):
                         filtered_documents.append(all_documents[i])
                     filtered_metadatas.append(metadata)
-                    
+
                     # Apply limit if specified
                     if limit and len(filtered_ids) >= limit:
                         break
-            
-            log.info(f"Filter applied: {len(filtered_ids)} vectors match out of {len(all_ids)} total")
-            
+
+            log.info(
+                f"Filter applied: {len(filtered_ids)} vectors match out of {len(all_ids)} total"
+            )
+
             # Return GetResult format
             if filtered_ids:
-                return GetResult(ids=[filtered_ids], documents=[filtered_documents], metadatas=[filtered_metadatas])
+                return GetResult(
+                    ids=[filtered_ids],
+                    documents=[filtered_documents],
+                    metadatas=[filtered_metadatas],
+                )
             else:
                 return GetResult(ids=[[]], documents=[[]], metadatas=[[]])
-            
+
         except Exception as e:
             log.error(f"Error querying collection '{collection_name}': {str(e)}")
             # Handle specific AWS exceptions
-            if hasattr(e, 'response') and 'Error' in e.response:
-                error_code = e.response['Error']['Code']
-                if error_code == 'NotFoundException':
+            if hasattr(e, "response") and "Error" in e.response:
+                error_code = e.response["Error"]["Code"]
+                if error_code == "NotFoundException":
                     log.warning(f"Collection '{collection_name}' not found")
                     log.warning(f"Collection '{collection_name}' not found")
                     return GetResult(ids=[[]], documents=[[]], metadatas=[[]])
                     return GetResult(ids=[[]], documents=[[]], metadatas=[[]])
-                elif error_code == 'AccessDeniedException':
-                    log.error(f"Access denied for collection '{collection_name}'. Check permissions.")
+                elif error_code == "AccessDeniedException":
+                    log.error(
+                        f"Access denied for collection '{collection_name}'. Check permissions."
+                    )
                     return GetResult(ids=[[]], documents=[[]], metadatas=[[]])
             raise
 
@@ -418,170 +463,203 @@ class S3VectorClient(VectorDBBase):
         """
         Retrieve all vectors from a collection.
         """
-            
+
         if not self.has_collection(collection_name):
             log.warning(f"Collection '{collection_name}' does not exist")
             return GetResult(ids=[[]], documents=[[]], metadatas=[[]])
-            
+
         try:
             log.info(f"Retrieving all vectors from collection '{collection_name}'")
-            
+
             # Initialize result lists
             all_ids = []
             all_documents = []
             all_metadatas = []
-            
+
             # Handle pagination
             next_token = None
-            
+
             while True:
                 # Prepare request parameters
                 request_params = {
-                    'vectorBucketName': self.bucket_name,
-                    'indexName': collection_name,
-                    'returnData': False,  # Don't include vector data (not needed for get)
-                    'returnMetadata': True,  # Include metadata
-                    'maxResults': 500  # Use reasonable page size
+                    "vectorBucketName": self.bucket_name,
+                    "indexName": collection_name,
+                    "returnData": False,  # Don't include vector data (not needed for get)
+                    "returnMetadata": True,  # Include metadata
+                    "maxResults": 500,  # Use reasonable page size
                 }
-                
+
                 if next_token:
-                    request_params['nextToken'] = next_token
-                
+                    request_params["nextToken"] = next_token
+
                 # Call S3 Vector API
                 response = self.client.list_vectors(**request_params)
-                
+
                 # Process vectors in this page
-                vectors = response.get('vectors', [])
-                
+                vectors = response.get("vectors", [])
+
                 for vector in vectors:
-                    vector_id = vector.get('key')
-                    vector_data = vector.get('data', {})
-                    vector_metadata = vector.get('metadata', {})
-                    
+                    vector_id = vector.get("key")
+                    vector_data = vector.get("data", {})
+                    vector_metadata = vector.get("metadata", {})
+
                     # Extract the actual vector array
-                    vector_array = vector_data.get('float32', [])
-                    
+                    vector_array = vector_data.get("float32", [])
+
                     # For documents, we try to extract text from metadata or use the vector ID
                     document_text = ""
                     if isinstance(vector_metadata, dict):
                         # Get the text field first (highest priority)
-                        document_text = vector_metadata.get('text')
+                        document_text = vector_metadata.get("text")
                         if not document_text:
                             # Fallback to other possible text fields
-                            document_text = (vector_metadata.get('content') or 
-                                           vector_metadata.get('document') or 
-                                           vector_id)
-                        
+                            document_text = (
+                                vector_metadata.get("content")
+                                or vector_metadata.get("document")
+                                or vector_id
+                            )
+
                         # Log the actual content for debugging
-                        log.debug(f"Document text preview (first 200 chars): {str(document_text)[:200]}")
+                        log.debug(
+                            f"Document text preview (first 200 chars): {str(document_text)[:200]}"
+                        )
                     else:
                         document_text = vector_id
-                    
+
                     all_ids.append(vector_id)
                     all_documents.append(document_text)
                     all_metadatas.append(vector_metadata)
-                
+
                 # Check if there are more pages
-                next_token = response.get('nextToken')
+                next_token = response.get("nextToken")
                 if not next_token:
                     break
-            
-            log.info(f"Retrieved {len(all_ids)} vectors from collection '{collection_name}'")
-            
+
+            log.info(
+                f"Retrieved {len(all_ids)} vectors from collection '{collection_name}'"
+            )
+
             # Return in GetResult format
             # The Open WebUI GetResult expects lists of lists, so we wrap each list
             if all_ids:
-                return GetResult(ids=[all_ids], documents=[all_documents], metadatas=[all_metadatas])
+                return GetResult(
+                    ids=[all_ids], documents=[all_documents], metadatas=[all_metadatas]
+                )
             else:
                 return GetResult(ids=[[]], documents=[[]], metadatas=[[]])
-            
+
         except Exception as e:
-            log.error(f"Error retrieving vectors from collection '{collection_name}': {str(e)}")
+            log.error(
+                f"Error retrieving vectors from collection '{collection_name}': {str(e)}"
+            )
             # Handle specific AWS exceptions
-            if hasattr(e, 'response') and 'Error' in e.response:
-                error_code = e.response['Error']['Code']
-                if error_code == 'NotFoundException':
+            if hasattr(e, "response") and "Error" in e.response:
+                error_code = e.response["Error"]["Code"]
+                if error_code == "NotFoundException":
                     log.warning(f"Collection '{collection_name}' not found")
                     log.warning(f"Collection '{collection_name}' not found")
                     return GetResult(ids=[[]], documents=[[]], metadatas=[[]])
                     return GetResult(ids=[[]], documents=[[]], metadatas=[[]])
-                elif error_code == 'AccessDeniedException':
-                    log.error(f"Access denied for collection '{collection_name}'. Check permissions.")
+                elif error_code == "AccessDeniedException":
+                    log.error(
+                        f"Access denied for collection '{collection_name}'. Check permissions."
+                    )
                     return GetResult(ids=[[]], documents=[[]], metadatas=[[]])
             raise
 
-    def delete(self, collection_name: str, ids: Optional[List[str]] = None, filter: Optional[Dict] = None) -> None:
+    def delete(
+        self,
+        collection_name: str,
+        ids: Optional[List[str]] = None,
+        filter: Optional[Dict] = None,
+    ) -> None:
         """
         """
         Delete vectors by ID or filter from a collection.
         Delete vectors by ID or filter from a collection.
         """
         """
-            
+
         if not self.has_collection(collection_name):
         if not self.has_collection(collection_name):
-            log.warning(f"Collection '{collection_name}' does not exist, nothing to delete")
+            log.warning(
+                f"Collection '{collection_name}' does not exist, nothing to delete"
+            )
             return
-            
+
         # Check if this is a knowledge collection (not file-specific)
         is_knowledge_collection = not collection_name.startswith("file-")
-            
+
         try:
             if ids:
                 # Delete by specific vector IDs/keys
-                log.info(f"Deleting {len(ids)} vectors by IDs from collection '{collection_name}'")
+                log.info(
+                    f"Deleting {len(ids)} vectors by IDs from collection '{collection_name}'"
+                )
                 self.client.delete_vectors(
                     vectorBucketName=self.bucket_name,
                     indexName=collection_name,
-                    keys=ids
+                    keys=ids,
                 )
                 log.info(f"Deleted {len(ids)} vectors from index '{collection_name}'")
-                        
+
             elif filter:
                 # Handle filter-based deletion
-                log.info(f"Deleting vectors by filter from collection '{collection_name}': {filter}")
-                
+                log.info(
+                    f"Deleting vectors by filter from collection '{collection_name}': {filter}"
+                )
+
                 # If this is a knowledge collection and we have a file_id filter,
                 # also clean up the corresponding file-specific collection
                 if is_knowledge_collection and "file_id" in filter:
                     file_id = filter["file_id"]
                     file_collection_name = f"file-{file_id}"
                     if self.has_collection(file_collection_name):
-                        log.info(f"Found related file-specific collection '{file_collection_name}', deleting it to prevent duplicates")
+                        log.info(
+                            f"Found related file-specific collection '{file_collection_name}', deleting it to prevent duplicates"
+                        )
                         self.delete_collection(file_collection_name)
-                
+
                 # For the main collection, implement query-then-delete
                 # First, query to get IDs matching the filter
                 query_result = self.query(collection_name, filter)
                 if query_result and query_result.ids and query_result.ids[0]:
                     matching_ids = query_result.ids[0]
-                    log.info(f"Found {len(matching_ids)} vectors matching filter, deleting them")
-                    
+                    log.info(
+                        f"Found {len(matching_ids)} vectors matching filter, deleting them"
+                    )
+
                     # Delete the matching vectors by ID
                     self.client.delete_vectors(
                         vectorBucketName=self.bucket_name,
                         indexName=collection_name,
-                        keys=matching_ids
+                        keys=matching_ids,
+                    )
+                    log.info(
+                        f"Deleted {len(matching_ids)} vectors from index '{collection_name}' using filter"
                     )
                     )
-                    log.info(f"Deleted {len(matching_ids)} vectors from index '{collection_name}' using filter")
                 else:
                     log.warning("No vectors found matching the filter criteria")
             else:
                 log.warning("No IDs or filter provided for deletion")
         except Exception as e:
+            log.error(
+                f"Error deleting vectors from collection '{collection_name}': {e}"
+            )
             raise
             raise
 
     def reset(self) -> None:
         """
         Reset/clear all vector data. For S3 Vector, this deletes all indexes.
         """
-            
+
         try:
-            
+            log.warning(
+                "Reset called - this will delete all vector indexes in the S3 bucket"
+            )
+
             # List all indexes
             # List all indexes
             response = self.client.list_indexes(vectorBucketName=self.bucket_name)
             indexes = response.get("indexes", [])
-            
+
             if not indexes:
                 log.warning("No indexes found to delete")
                 return
-                
+
             # Delete all indexes
             deleted_count = 0
             for index in indexes:
                 if index_name:
                 if index_name:
                     try:
                         self.client.delete_index(
-                            indexName=index_name
+                            vectorBucketName=self.bucket_name, indexName=index_name
                         )
                         )
                         deleted_count += 1
                         log.info(f"Deleted index: {index_name}")
                     except Exception as e:
                         log.error(f"Error deleting index '{index_name}': {e}")
-                        
+
             log.info(f"Reset completed: deleted {deleted_count} indexes")
-            
+
         except Exception as e:
             log.error(f"Error during reset: {e}")
             raise
-    
+
     def _matches_filter(self, metadata: Dict[str, Any], filter: Dict[str, Any]) -> bool:
         """
         Check if metadata matches the given filter conditions.
         """
         if not isinstance(metadata, dict) or not isinstance(filter, dict):
             return False
-            
+
         # Check each filter condition
         for key, expected_value in filter.items():
             # Handle special operators
-                if key == '$and':
+            if key.startswith("$"):
+                if key == "$and":
                     # All conditions must match
                     # All conditions must match
                     if not isinstance(expected_value, list):
                         continue
                     for condition in expected_value:
                         if not self._matches_filter(metadata, condition):
                             return False
-                elif key == '$or':
+                elif key == "$or":
                     # At least one condition must match
                     if not isinstance(expected_value, list):
                         continue
                     if not any_match:
                     if not any_match:
                         return False
                 continue
-            
+
             # Get the actual value from metadata
             actual_value = metadata.get(key)
-            
+
             # Handle different types of expected values
             if isinstance(expected_value, dict):
                 # Handle comparison operators
                 for op, op_value in expected_value.items():
-                    if op == '$eq':
+                    if op == "$eq":
                         if actual_value != op_value:
                             return False
-                    elif op == '$ne':
+                    elif op == "$ne":
                         if actual_value == op_value:
                             return False
-                        if not isinstance(op_value, list) or actual_value not in op_value:
+                    elif op == "$in":
+                        if (
+                            not isinstance(op_value, list)
+                            or actual_value not in op_value
+                        ):
                             return False
-                    elif op == '$nin':
+                    elif op == "$nin":
                         if isinstance(op_value, list) and actual_value in op_value:
                             return False
-                    elif op == '$exists':
+                    elif op == "$exists":
                         if bool(op_value) != (key in metadata):
                             return False
                 # Add more operators as needed
@@ -661,5 +741,5 @@ class S3VectorClient(VectorDBBase):
             # Simple equality check
             if actual_value != expected_value:
                 return False
-                    
+
         return True

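Note: the _matches_filter helper reformatted above implements Mongo-style, client-side metadata filtering ($and, $or, $eq, $ne, $in, $nin, $exists, plus plain equality). The following is a simplified standalone sketch of that matching behaviour, for illustration only; it omits $ne, $nin and $exists, and none of the names below come from the diff itself.

from typing import Any, Dict


def matches(metadata: Dict[str, Any], flt: Dict[str, Any]) -> bool:
    # Every key in the filter must be satisfied for the metadata to match.
    for key, expected in flt.items():
        if key == "$and":
            if not all(matches(metadata, cond) for cond in expected):
                return False
            continue
        if key == "$or":
            if not any(matches(metadata, cond) for cond in expected):
                return False
            continue
        actual = metadata.get(key)
        if isinstance(expected, dict):
            # Operator form, e.g. {"page": {"$in": [1, 2, 3]}}
            if "$eq" in expected and actual != expected["$eq"]:
                return False
            if "$in" in expected and actual not in expected["$in"]:
                return False
        elif actual != expected:
            # Plain equality form, e.g. {"file_id": "abc"}
            return False
    return True


# Example: the kind of file_id filter delete()/query() receive
meta = {"file_id": "abc", "page": 3}
assert matches(meta, {"file_id": "abc"})
assert matches(meta, {"$or": [{"file_id": "xyz"}, {"page": {"$in": [1, 2, 3]}}]})
assert not matches(meta, {"$and": [{"file_id": "abc"}, {"page": {"$eq": 4}}]})
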
+ 2 - 0
backend/open_webui/retrieval/vector/factory.py

@@ -28,9 +28,11 @@ class Vector:
                     return QdrantClient()
             case VectorType.PINECONE:
                 from open_webui.retrieval.vector.dbs.pinecone import PineconeClient
+
                 return PineconeClient()
             case VectorType.S3VECTOR:
                 from open_webui.retrieval.vector.dbs.s3vector import S3VectorClient
+
                 return S3VectorClient()
             case VectorType.OPENSEARCH:
                 from open_webui.retrieval.vector.dbs.opensearch import OpenSearchClient
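
Note: the two blank lines added to factory.py look like the formatter's rule of separating an import from the code that follows it; the surrounding match statement keeps each vector-DB client imported lazily, so optional dependencies are only pulled in for the configured backend. A rough self-contained sketch of that dispatch pattern follows; the factory method name and the full VectorType enum are not visible in this hunk, so the names below are assumptions.

from enum import Enum


class VectorType(str, Enum):
    # Only the members visible in the hunk; the real enum has more values.
    PINECONE = "pinecone"
    S3VECTOR = "s3vector"
    OPENSEARCH = "opensearch"


def get_client(vector_type: VectorType):
    # Import inside each case so a backend's SDK is loaded only when selected.
    match vector_type:
        case VectorType.S3VECTOR:
            from open_webui.retrieval.vector.dbs.s3vector import S3VectorClient

            return S3VectorClient()
        case VectorType.PINECONE:
            from open_webui.retrieval.vector.dbs.pinecone import PineconeClient

            return PineconeClient()
        case _:
            raise ValueError(f"Unsupported vector DB: {vector_type}")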