from typing import Optional, List, Dict, Any, Union
import logging
import time  # for measuring elapsed time

from pinecone import Pinecone, ServerlessSpec

# Add gRPC support for better performance (Pinecone best practice)
try:
    from pinecone.grpc import PineconeGRPC

    GRPC_AVAILABLE = True
except ImportError:
    GRPC_AVAILABLE = False

import asyncio  # for async upserts
import functools  # for partial binding in async tasks
import concurrent.futures  # for parallel batch upserts
import random  # for jitter in retry backoff

from open_webui.retrieval.vector.main import (
    VectorDBBase,
    VectorItem,
    SearchResult,
    GetResult,
)
from open_webui.config import (
    PINECONE_API_KEY,
    PINECONE_ENVIRONMENT,
    PINECONE_INDEX_NAME,
    PINECONE_DIMENSION,
    PINECONE_METRIC,
    PINECONE_CLOUD,
)
from open_webui.env import SRC_LOG_LEVELS

NO_LIMIT = 10000  # Reasonable limit to avoid overwhelming the system
BATCH_SIZE = 100  # Recommended batch size for Pinecone operations

log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])


class PineconeClient(VectorDBBase):
    def __init__(self):
        self.collection_prefix = "open-webui"

        # Validate required configuration
        self._validate_config()

        # Store configuration values
        self.api_key = PINECONE_API_KEY
        self.environment = PINECONE_ENVIRONMENT
        self.index_name = PINECONE_INDEX_NAME
        self.dimension = PINECONE_DIMENSION
        self.metric = PINECONE_METRIC
        self.cloud = PINECONE_CLOUD

        # Initialize the Pinecone client, preferring gRPC for performance
        if GRPC_AVAILABLE:
            # Use the gRPC client for better performance (Pinecone recommendation)
            self.client = PineconeGRPC(
                api_key=self.api_key,
                pool_threads=20,  # Enlarged connection pool size
                timeout=30,  # Reasonable timeout for operations
            )
            self.using_grpc = True
            log.info("Using Pinecone gRPC client for optimal performance")
        else:
            # Fall back to the HTTP client with enhanced connection pooling
            self.client = Pinecone(
                api_key=self.api_key,
                pool_threads=20,  # Enlarged connection pool size
                timeout=30,  # Reasonable timeout for operations
            )
            self.using_grpc = False
            log.info("Using Pinecone HTTP client (gRPC not available)")

        # Persistent executor for batch operations
        self._executor = concurrent.futures.ThreadPoolExecutor(max_workers=5)

        # Create the index if it doesn't exist
        self._initialize_index()
    def _validate_config(self) -> None:
        """Validate that all required configuration variables are set."""
        missing_vars = []
        if not PINECONE_API_KEY:
            missing_vars.append("PINECONE_API_KEY")
        if not PINECONE_ENVIRONMENT:
            missing_vars.append("PINECONE_ENVIRONMENT")
        if not PINECONE_INDEX_NAME:
            missing_vars.append("PINECONE_INDEX_NAME")
        if not PINECONE_DIMENSION:
            missing_vars.append("PINECONE_DIMENSION")
        if not PINECONE_CLOUD:
            missing_vars.append("PINECONE_CLOUD")

        if missing_vars:
            raise ValueError(
                f"Required configuration missing: {', '.join(missing_vars)}"
            )
    def _initialize_index(self) -> None:
        """Initialize the Pinecone index, creating it if it does not exist."""
        try:
            # Check whether the index exists
            if self.index_name not in self.client.list_indexes().names():
                log.info(f"Creating Pinecone index '{self.index_name}'...")
                self.client.create_index(
                    name=self.index_name,
                    dimension=self.dimension,
                    metric=self.metric,
                    spec=ServerlessSpec(cloud=self.cloud, region=self.environment),
                )
                log.info(f"Successfully created Pinecone index '{self.index_name}'")
            else:
                log.info(f"Using existing Pinecone index '{self.index_name}'")

            # Connect to the index
            self.index = self.client.Index(
                self.index_name,
                pool_threads=20,  # Enhanced connection pool for index operations
            )
        except Exception as e:
            log.error(f"Failed to initialize Pinecone index: {e}")
            raise RuntimeError(f"Failed to initialize Pinecone index: {e}")
    def _retry_pinecone_operation(self, operation_func, max_retries=3):
        """Retry Pinecone operations with exponential backoff for rate limits and network issues."""
        for attempt in range(max_retries):
            try:
                return operation_func()
            except Exception as e:
                error_str = str(e).lower()
                # Check whether it's a retryable error (rate limits, network issues, timeouts)
                is_retryable = any(
                    keyword in error_str
                    for keyword in [
                        "rate limit",
                        "quota",
                        "timeout",
                        "network",
                        "connection",
                        "unavailable",
                        "internal error",
                        "429",
                        "500",
                        "502",
                        "503",
                        "504",
                    ]
                )
                if not is_retryable or attempt == max_retries - 1:
                    # Don't retry non-retryable errors, and don't retry past the final attempt
                    raise
                # Exponential backoff with jitter
                delay = (2**attempt) + random.uniform(0, 1)
                log.warning(
                    f"Pinecone operation failed (attempt {attempt + 1}/{max_retries}), "
                    f"retrying in {delay:.2f}s: {e}"
                )
                time.sleep(delay)
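
    # Illustrative only: _retry_pinecone_operation is a generic wrapper and is not
    # wired into the batch insert/upsert paths below. A call site could look like
    # the following (the "batch" variable is hypothetical):
    #
    #     self._retry_pinecone_operation(
    #         lambda: self.index.upsert(vectors=batch)
    #     )
    #
    # Passing a zero-argument callable (a lambda or functools.partial) keeps the
    # helper agnostic of which Pinecone operation is being retried.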
    def _create_points(
        self, items: List[VectorItem], collection_name_with_prefix: str
    ) -> List[Dict[str, Any]]:
        """Convert VectorItem objects to Pinecone point format."""
        points = []
        for item in items:
            # Start with any existing metadata or an empty dict
            metadata = item.get("metadata", {}).copy() if item.get("metadata") else {}

            # Add text to metadata if available
            if "text" in item:
                metadata["text"] = item["text"]

            # Always add collection_name to metadata for filtering
            metadata["collection_name"] = collection_name_with_prefix

            point = {
                "id": item["id"],
                "values": item["vector"],
                "metadata": metadata,
            }
            points.append(point)
        return points
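
    # Sketch of the mapping performed above, using purely hypothetical values:
    # a VectorItem such as
    #     {"id": "doc-1", "vector": [0.1, 0.2], "text": "hello", "metadata": {"page": 1}}
    # for the prefixed collection "open-webui_docs" becomes the Pinecone point
    #     {"id": "doc-1", "values": [0.1, 0.2],
    #      "metadata": {"page": 1, "text": "hello", "collection_name": "open-webui_docs"}}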
    def _get_collection_name_with_prefix(self, collection_name: str) -> str:
        """Get the collection name with prefix."""
        return f"{self.collection_prefix}_{collection_name}"
    def _normalize_distance(self, score: float) -> float:
        """Normalize the distance score based on the metric used."""
        if self.metric.lower() == "cosine":
            # Cosine similarity ranges from -1 to 1; normalize to 0 to 1
            return (score + 1.0) / 2.0
        elif self.metric.lower() in ["euclidean", "dotproduct"]:
            # These scores are already suitable for ranking (smaller is better for Euclidean)
            return score
        else:
            # For other metrics, use the score as-is
            return score
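
    # Worked example (assuming the cosine metric): a raw Pinecone score of 0.5
    # maps to (0.5 + 1.0) / 2.0 = 0.75, so all cosine scores land in [0, 1]
    # while preserving their relative ordering.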
    def _result_to_get_result(self, matches: list) -> GetResult:
        """Convert Pinecone matches to GetResult format."""
        ids = []
        documents = []
        metadatas = []

        for match in matches:
            metadata = getattr(match, "metadata", {}) or {}
            ids.append(match.id if hasattr(match, "id") else match["id"])
            documents.append(metadata.get("text", ""))
            metadatas.append(metadata)

        return GetResult(
            **{
                "ids": [ids],
                "documents": [documents],
                "metadatas": [metadatas],
            }
        )
    def has_collection(self, collection_name: str) -> bool:
        """Check if a collection exists by searching for at least one item."""
        collection_name_with_prefix = self._get_collection_name_with_prefix(
            collection_name
        )
        try:
            # Search for at least one item with this collection name in metadata
            response = self.index.query(
                vector=[0.0] * self.dimension,  # dummy vector
                top_k=1,
                filter={"collection_name": collection_name_with_prefix},
                include_metadata=False,
            )
            matches = getattr(response, "matches", []) or []
            return len(matches) > 0
        except Exception as e:
            log.exception(
                f"Error checking collection '{collection_name_with_prefix}': {e}"
            )
            return False
    def delete_collection(self, collection_name: str) -> None:
        """Delete a collection by removing all vectors with the collection name in metadata."""
        collection_name_with_prefix = self._get_collection_name_with_prefix(
            collection_name
        )
        try:
            self.index.delete(filter={"collection_name": collection_name_with_prefix})
            log.info(
                f"Collection '{collection_name_with_prefix}' deleted (all vectors removed)."
            )
        except Exception as e:
            log.warning(
                f"Failed to delete collection '{collection_name_with_prefix}': {e}"
            )
            raise
    def insert(self, collection_name: str, items: List[VectorItem]) -> None:
        """Insert vectors into a collection."""
        if not items:
            log.warning("No items to insert")
            return

        start_time = time.time()
        collection_name_with_prefix = self._get_collection_name_with_prefix(
            collection_name
        )
        points = self._create_points(items, collection_name_with_prefix)

        # Parallelize batch inserts for performance
        executor = self._executor
        futures = []
        for i in range(0, len(points), BATCH_SIZE):
            batch = points[i : i + BATCH_SIZE]
            futures.append(executor.submit(self.index.upsert, vectors=batch))
        for future in concurrent.futures.as_completed(futures):
            try:
                future.result()
            except Exception as e:
                log.error(f"Error inserting batch: {e}")
                raise

        elapsed = time.time() - start_time
        log.debug(f"Insert of {len(points)} vectors took {elapsed:.2f} seconds")
        log.info(
            f"Successfully inserted {len(points)} vectors in parallel batches "
            f"into '{collection_name_with_prefix}'"
        )
    def upsert(self, collection_name: str, items: List[VectorItem]) -> None:
        """Upsert (insert or update) vectors into a collection."""
        if not items:
            log.warning("No items to upsert")
            return

        start_time = time.time()
        collection_name_with_prefix = self._get_collection_name_with_prefix(
            collection_name
        )
        points = self._create_points(items, collection_name_with_prefix)

        # Parallelize batch upserts for performance
        executor = self._executor
        futures = []
        for i in range(0, len(points), BATCH_SIZE):
            batch = points[i : i + BATCH_SIZE]
            futures.append(executor.submit(self.index.upsert, vectors=batch))
        for future in concurrent.futures.as_completed(futures):
            try:
                future.result()
            except Exception as e:
                log.error(f"Error upserting batch: {e}")
                raise

        elapsed = time.time() - start_time
        log.debug(f"Upsert of {len(points)} vectors took {elapsed:.2f} seconds")
        log.info(
            f"Successfully upserted {len(points)} vectors in parallel batches "
            f"into '{collection_name_with_prefix}'"
        )
    async def insert_async(self, collection_name: str, items: List[VectorItem]) -> None:
        """Async version of insert using asyncio and run_in_executor for improved performance."""
        if not items:
            log.warning("No items to insert")
            return

        collection_name_with_prefix = self._get_collection_name_with_prefix(
            collection_name
        )
        points = self._create_points(items, collection_name_with_prefix)

        # Create batches
        batches = [
            points[i : i + BATCH_SIZE] for i in range(0, len(points), BATCH_SIZE)
        ]

        # get_running_loop() is the idiomatic way to obtain the loop inside a coroutine
        loop = asyncio.get_running_loop()
        tasks = [
            loop.run_in_executor(
                None, functools.partial(self.index.upsert, vectors=batch)
            )
            for batch in batches
        ]
        results = await asyncio.gather(*tasks, return_exceptions=True)

        for result in results:
            if isinstance(result, Exception):
                log.error(f"Error in async insert batch: {result}")
                raise result

        log.info(
            f"Successfully async inserted {len(points)} vectors in batches "
            f"into '{collection_name_with_prefix}'"
        )
    async def upsert_async(self, collection_name: str, items: List[VectorItem]) -> None:
        """Async version of upsert using asyncio and run_in_executor for improved performance."""
        if not items:
            log.warning("No items to upsert")
            return

        collection_name_with_prefix = self._get_collection_name_with_prefix(
            collection_name
        )
        points = self._create_points(items, collection_name_with_prefix)

        # Create batches
        batches = [
            points[i : i + BATCH_SIZE] for i in range(0, len(points), BATCH_SIZE)
        ]

        # get_running_loop() is the idiomatic way to obtain the loop inside a coroutine
        loop = asyncio.get_running_loop()
        tasks = [
            loop.run_in_executor(
                None, functools.partial(self.index.upsert, vectors=batch)
            )
            for batch in batches
        ]
        results = await asyncio.gather(*tasks, return_exceptions=True)

        for result in results:
            if isinstance(result, Exception):
                log.error(f"Error in async upsert batch: {result}")
                raise result

        log.info(
            f"Successfully async upserted {len(points)} vectors in batches "
            f"into '{collection_name_with_prefix}'"
        )
    def search(
        self, collection_name: str, vectors: List[List[Union[float, int]]], limit: int
    ) -> Optional[SearchResult]:
        """Search for similar vectors in a collection."""
        if not vectors or not vectors[0]:
            log.warning("No vectors provided for search")
            return None

        collection_name_with_prefix = self._get_collection_name_with_prefix(
            collection_name
        )

        if limit is None or limit <= 0:
            limit = NO_LIMIT

        try:
            # Pinecone queries take a single vector, so only the first query vector is used
            query_vector = vectors[0]

            # Perform the search
            query_response = self.index.query(
                vector=query_vector,
                top_k=limit,
                include_metadata=True,
                filter={"collection_name": collection_name_with_prefix},
            )
            matches = getattr(query_response, "matches", []) or []

            if not matches:
                # Return an empty result if there are no matches
                return SearchResult(
                    ids=[[]],
                    documents=[[]],
                    metadatas=[[]],
                    distances=[[]],
                )

            # Convert to GetResult format
            get_result = self._result_to_get_result(matches)

            # Calculate normalized distances based on the metric
            distances = [
                [
                    self._normalize_distance(getattr(match, "score", 0.0))
                    for match in matches
                ]
            ]

            return SearchResult(
                ids=get_result.ids,
                documents=get_result.documents,
                metadatas=get_result.metadatas,
                distances=distances,
            )
        except Exception as e:
            log.error(f"Error searching in '{collection_name_with_prefix}': {e}")
            return None
    def query(
        self, collection_name: str, filter: Dict, limit: Optional[int] = None
    ) -> Optional[GetResult]:
        """Query vectors by metadata filter."""
        collection_name_with_prefix = self._get_collection_name_with_prefix(
            collection_name
        )

        if limit is None or limit <= 0:
            limit = NO_LIMIT

        try:
            # Create a zero vector for the dimension, since Pinecone requires a query vector
            zero_vector = [0.0] * self.dimension

            # Combine the user filter with the collection_name
            pinecone_filter = {"collection_name": collection_name_with_prefix}
            if filter:
                pinecone_filter.update(filter)

            # Perform a metadata-only query
            query_response = self.index.query(
                vector=zero_vector,
                filter=pinecone_filter,
                top_k=limit,
                include_metadata=True,
            )
            matches = getattr(query_response, "matches", []) or []
            return self._result_to_get_result(matches)
        except Exception as e:
            log.error(f"Error querying collection '{collection_name}': {e}")
            return None
    def get(self, collection_name: str) -> Optional[GetResult]:
        """Get all vectors in a collection."""
        collection_name_with_prefix = self._get_collection_name_with_prefix(
            collection_name
        )
        try:
            # Use a zero vector for fetching all entries
            zero_vector = [0.0] * self.dimension

            # Filter to only the vectors belonging to this collection
            query_response = self.index.query(
                vector=zero_vector,
                top_k=NO_LIMIT,
                include_metadata=True,
                filter={"collection_name": collection_name_with_prefix},
            )
            matches = getattr(query_response, "matches", []) or []
            return self._result_to_get_result(matches)
        except Exception as e:
            log.error(f"Error getting collection '{collection_name}': {e}")
            return None
    def delete(
        self,
        collection_name: str,
        ids: Optional[List[str]] = None,
        filter: Optional[Dict] = None,
    ) -> None:
        """Delete vectors by IDs or filter."""
        collection_name_with_prefix = self._get_collection_name_with_prefix(
            collection_name
        )
        try:
            if ids:
                # Delete by IDs (in batches for large deletions)
                for i in range(0, len(ids), BATCH_SIZE):
                    batch_ids = ids[i : i + BATCH_SIZE]
                    # Note: when deleting by ID, we can't also filter by collection_name.
                    # This is a Pinecone limitation, so be careful with ID uniqueness.
                    self.index.delete(ids=batch_ids)
                    log.debug(
                        f"Deleted batch of {len(batch_ids)} vectors by ID "
                        f"from '{collection_name_with_prefix}'"
                    )
                log.info(
                    f"Successfully deleted {len(ids)} vectors by ID "
                    f"from '{collection_name_with_prefix}'"
                )
            elif filter:
                # Combine the user filter with the collection_name
                pinecone_filter = {"collection_name": collection_name_with_prefix}
                if filter:
                    pinecone_filter.update(filter)
                # Delete by metadata filter
                self.index.delete(filter=pinecone_filter)
                log.info(
                    f"Successfully deleted vectors by filter from '{collection_name_with_prefix}'"
                )
            else:
                log.warning("No ids or filter provided for delete operation")
        except Exception as e:
            log.error(f"Error deleting from collection '{collection_name}': {e}")
            raise
    def reset(self) -> None:
        """Reset the database by deleting all collections."""
        try:
            self.index.delete(delete_all=True)
            log.info("All vectors successfully deleted from the index.")
        except Exception as e:
            log.error(f"Failed to reset Pinecone index: {e}")
            raise
    def close(self):
        """Shut down resources."""
        try:
            # The current Pinecone clients don't require an explicit close;
            # only the thread pool executor needs to be shut down.
            self._executor.shutdown(wait=True)
        except Exception as e:
            log.warning(f"Failed to clean up Pinecone resources: {e}")
    def __enter__(self):
        """Enter the context manager."""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Exit the context manager, ensuring resources are cleaned up."""
        self.close()
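

# Minimal usage sketch, illustrative only: it assumes valid PINECONE_* settings
# are present in open_webui.config, and the collection name and item values
# below are hypothetical. It is not part of the library API.
if __name__ == "__main__":
    with PineconeClient() as client:
        dim = int(client.dimension)
        items = [
            {
                "id": "example-1",
                "vector": [0.1] * dim,  # dummy embedding of the configured dimension
                "text": "hello world",
                "metadata": {"source": "example"},
            }
        ]
        client.upsert("example-docs", items)

        result = client.search("example-docs", [[0.1] * dim], limit=1)
        if result:
            print(result.ids, result.distances)

        # Clean up the example vector
        client.delete("example-docs", ids=["example-1"])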