Merge pull request #13073 from athoik/dev

feat(vector-db): add support for Pinecone client
Tim Jaeryang Baek, 3 months ago
parent commit 913f8a15f9

+ 8 - 0
backend/open_webui/config.py

@@ -1736,6 +1736,14 @@ PGVECTOR_INITIALIZE_MAX_VECTOR_LENGTH = int(
     os.environ.get("PGVECTOR_INITIALIZE_MAX_VECTOR_LENGTH", "1536")
 )
 
+# Pinecone
+PINECONE_API_KEY = os.environ.get("PINECONE_API_KEY", None)
+PINECONE_ENVIRONMENT = os.environ.get("PINECONE_ENVIRONMENT", None)
+PINECONE_INDEX_NAME = os.getenv("PINECONE_INDEX_NAME", "open-webui-index")
+PINECONE_DIMENSION = int(os.getenv("PINECONE_DIMENSION", 1536))  # or 3072, 1024, 768
+PINECONE_METRIC = os.getenv("PINECONE_METRIC", "cosine")
+PINECONE_CLOUD = os.getenv("PINECONE_CLOUD", "aws")  # or "gcp" or "azure"
+
 ####################################
 # Information Retrieval (RAG)
 ####################################
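With these additions, the Pinecone backend is configured entirely from the environment. A minimal sketch of how the values resolve, assuming only the two required secrets are exported and everything else falls back to the defaults above (the example values are placeholders, not part of this PR):

    import os

    os.environ["PINECONE_API_KEY"] = "pc-..."         # required; no default
    os.environ["PINECONE_ENVIRONMENT"] = "us-east-1"  # required; also used as the serverless region

    # With nothing else set, config.py resolves to:
    #   PINECONE_INDEX_NAME = "open-webui-index"
    #   PINECONE_DIMENSION  = 1536      (or 3072, 1024, 768 for other embedding sizes)
    #   PINECONE_METRIC     = "cosine"
    #   PINECONE_CLOUD      = "aws"     (or "gcp" or "azure")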

+ 4 - 0
backend/open_webui/retrieval/vector/connector.py

@@ -20,6 +20,10 @@ elif VECTOR_DB == "elasticsearch":
     from open_webui.retrieval.vector.dbs.elasticsearch import ElasticsearchClient
 
     VECTOR_DB_CLIENT = ElasticsearchClient()
+elif VECTOR_DB == "pinecone":
+    from open_webui.retrieval.vector.dbs.pinecone import PineconeClient
+
+    VECTOR_DB_CLIENT = PineconeClient()
 else:
     from open_webui.retrieval.vector.dbs.chroma import ChromaClient
 
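The dispatch falls through to ChromaClient for any unrecognized value, so opting into the new backend requires VECTOR_DB=pinecone. A minimal sketch, assuming the environment from the config hunk above (VECTOR_DB itself is pre-existing Open WebUI configuration, not added by this PR):

    import os

    os.environ["VECTOR_DB"] = "pinecone"  # any other value selects a different backend

    # Importing the connector now yields a PineconeClient instance
    from open_webui.retrieval.vector.connector import VECTOR_DB_CLIENT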

+ 407 - 0
backend/open_webui/retrieval/vector/dbs/pinecone.py

@@ -0,0 +1,407 @@
+from typing import Optional, List, Dict, Any, Union
+import logging
+from pinecone import Pinecone, ServerlessSpec
+
+from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult
+from open_webui.config import (
+    PINECONE_API_KEY,
+    PINECONE_ENVIRONMENT,
+    PINECONE_INDEX_NAME,
+    PINECONE_DIMENSION,
+    PINECONE_METRIC,
+    PINECONE_CLOUD,
+)
+from open_webui.env import SRC_LOG_LEVELS
+
+NO_LIMIT = 10000  # Reasonable limit to avoid overwhelming the system
+BATCH_SIZE = 100  # Recommended batch size for Pinecone operations
+
+log = logging.getLogger(__name__)
+log.setLevel(SRC_LOG_LEVELS["RAG"])
+
+
+class PineconeClient:
+    def __init__(self):
+        self.collection_prefix = "open-webui"
+
+        # Validate required configuration
+        self._validate_config()
+
+        # Store configuration values
+        self.api_key = PINECONE_API_KEY
+        self.environment = PINECONE_ENVIRONMENT
+        self.index_name = PINECONE_INDEX_NAME
+        self.dimension = PINECONE_DIMENSION
+        self.metric = PINECONE_METRIC
+        self.cloud = PINECONE_CLOUD
+
+        # Initialize Pinecone client
+        self.client = Pinecone(api_key=self.api_key)
+
+        # Create index if it doesn't exist
+        self._initialize_index()
+
+    def _validate_config(self) -> None:
+        """Validate that all required configuration variables are set."""
+        missing_vars = []
+        if not PINECONE_API_KEY:
+            missing_vars.append("PINECONE_API_KEY")
+        if not PINECONE_ENVIRONMENT:
+            missing_vars.append("PINECONE_ENVIRONMENT")
+        if not PINECONE_INDEX_NAME:
+            missing_vars.append("PINECONE_INDEX_NAME")
+        if not PINECONE_DIMENSION:
+            missing_vars.append("PINECONE_DIMENSION")
+        if not PINECONE_CLOUD:
+            missing_vars.append("PINECONE_CLOUD")
+
+        if missing_vars:
+            raise ValueError(
+                f"Required configuration missing: {', '.join(missing_vars)}"
+            )
+
+    def _initialize_index(self) -> None:
+        """Initialize the Pinecone index."""
+        try:
+            # Check if index exists
+            if self.index_name not in self.client.list_indexes().names():
+                log.info(f"Creating Pinecone index '{self.index_name}'...")
+                self.client.create_index(
+                    name=self.index_name,
+                    dimension=self.dimension,
+                    metric=self.metric,
+                    spec=ServerlessSpec(cloud=self.cloud, region=self.environment),
+                )
+                log.info(f"Successfully created Pinecone index '{self.index_name}'")
+            else:
+                log.info(f"Using existing Pinecone index '{self.index_name}'")
+
+            # Connect to the index
+            self.index = self.client.Index(self.index_name)
+
+        except Exception as e:
+            log.error(f"Failed to initialize Pinecone index: {e}")
+            raise RuntimeError(f"Failed to initialize Pinecone index: {e}")
+
+    def _create_points(
+        self, items: List[VectorItem], collection_name_with_prefix: str
+    ) -> List[Dict[str, Any]]:
+        """Convert VectorItem objects to Pinecone point format."""
+        points = []
+        for item in items:
+            # Start with any existing metadata or an empty dict
+            metadata = item.get("metadata", {}).copy() if item.get("metadata") else {}
+
+            # Add text to metadata if available
+            if "text" in item:
+                metadata["text"] = item["text"]
+
+            # Always add collection_name to metadata for filtering
+            metadata["collection_name"] = collection_name_with_prefix
+
+            point = {
+                "id": item["id"],
+                "values": item["vector"],
+                "metadata": metadata,
+            }
+            points.append(point)
+        return points
+
+    def _get_collection_name_with_prefix(self, collection_name: str) -> str:
+        """Get the collection name with prefix."""
+        return f"{self.collection_prefix}_{collection_name}"
+
+    def _normalize_distance(self, score: float) -> float:
+        """Normalize distance score based on the metric used."""
+        if self.metric.lower() == "cosine":
+            # Cosine similarity ranges from -1 to 1, normalize to 0 to 1
+            return (score + 1.0) / 2.0
+        elif self.metric.lower() in ["euclidean", "dotproduct"]:
+            # These are already suitable for ranking (smaller is better for Euclidean)
+            return score
+        else:
+            # For other metrics, use as is
+            return score
+
+    def _result_to_get_result(self, matches: list) -> GetResult:
+        """Convert Pinecone matches to GetResult format."""
+        ids = []
+        documents = []
+        metadatas = []
+
+        for match in matches:
+            metadata = match.get("metadata", {})
+            ids.append(match["id"])
+            documents.append(metadata.get("text", ""))
+            metadatas.append(metadata)
+
+        return GetResult(
+            **{
+                "ids": [ids],
+                "documents": [documents],
+                "metadatas": [metadatas],
+            }
+        )
+
+    def has_collection(self, collection_name: str) -> bool:
+        """Check if a collection exists by searching for at least one item."""
+        collection_name_with_prefix = self._get_collection_name_with_prefix(
+            collection_name
+        )
+
+        try:
+            # Search for at least 1 item with this collection name in metadata
+            response = self.index.query(
+                vector=[0.0] * self.dimension,  # dummy vector
+                top_k=1,
+                filter={"collection_name": collection_name_with_prefix},
+                include_metadata=False,
+            )
+            return len(response.matches) > 0
+        except Exception as e:
+            log.exception(
+                f"Error checking collection '{collection_name_with_prefix}': {e}"
+            )
+            return False
+
+    def delete_collection(self, collection_name: str) -> None:
+        """Delete a collection by removing all vectors with the collection name in metadata."""
+        collection_name_with_prefix = self._get_collection_name_with_prefix(
+            collection_name
+        )
+        try:
+            self.index.delete(filter={"collection_name": collection_name_with_prefix})
+            log.info(
+                f"Collection '{collection_name_with_prefix}' deleted (all vectors removed)."
+            )
+        except Exception as e:
+            log.warning(
+                f"Failed to delete collection '{collection_name_with_prefix}': {e}"
+            )
+            raise
+
+    def insert(self, collection_name: str, items: List[VectorItem]) -> None:
+        """Insert vectors into a collection."""
+        if not items:
+            log.warning("No items to insert")
+            return
+
+        collection_name_with_prefix = self._get_collection_name_with_prefix(
+            collection_name
+        )
+        points = self._create_points(items, collection_name_with_prefix)
+
+        # Insert in batches for better performance and reliability
+        for i in range(0, len(points), BATCH_SIZE):
+            batch = points[i : i + BATCH_SIZE]
+            try:
+                self.index.upsert(vectors=batch)
+                log.debug(
+                    f"Inserted batch of {len(batch)} vectors into '{collection_name_with_prefix}'"
+                )
+            except Exception as e:
+                log.error(
+                    f"Error inserting batch into '{collection_name_with_prefix}': {e}"
+                )
+                raise
+
+        log.info(
+            f"Successfully inserted {len(items)} vectors into '{collection_name_with_prefix}'"
+        )
+
+    def upsert(self, collection_name: str, items: List[VectorItem]) -> None:
+        """Upsert (insert or update) vectors into a collection."""
+        if not items:
+            log.warning("No items to upsert")
+            return
+
+        collection_name_with_prefix = self._get_collection_name_with_prefix(
+            collection_name
+        )
+        points = self._create_points(items, collection_name_with_prefix)
+
+        # Upsert in batches
+        for i in range(0, len(points), BATCH_SIZE):
+            batch = points[i : i + BATCH_SIZE]
+            try:
+                self.index.upsert(vectors=batch)
+                log.debug(
+                    f"Upserted batch of {len(batch)} vectors into '{collection_name_with_prefix}'"
+                )
+            except Exception as e:
+                log.error(
+                    f"Error upserting batch into '{collection_name_with_prefix}': {e}"
+                )
+                raise
+
+        log.info(
+            f"Successfully upserted {len(items)} vectors into '{collection_name_with_prefix}'"
+        )
+
+    def search(
+        self, collection_name: str, vectors: List[List[Union[float, int]]], limit: int
+    ) -> Optional[SearchResult]:
+        """Search for similar vectors in a collection."""
+        if not vectors or not vectors[0]:
+            log.warning("No vectors provided for search")
+            return None
+
+        collection_name_with_prefix = self._get_collection_name_with_prefix(
+            collection_name
+        )
+
+        if limit is None or limit <= 0:
+            limit = NO_LIMIT
+
+        try:
+            # Search using the first vector (assuming this is the intended behavior)
+            query_vector = vectors[0]
+
+            # Perform the search
+            query_response = self.index.query(
+                vector=query_vector,
+                top_k=limit,
+                include_metadata=True,
+                filter={"collection_name": collection_name_with_prefix},
+            )
+
+            if not query_response.matches:
+                # Return empty result if no matches
+                return SearchResult(
+                    ids=[[]],
+                    documents=[[]],
+                    metadatas=[[]],
+                    distances=[[]],
+                )
+
+            # Convert to GetResult format
+            get_result = self._result_to_get_result(query_response.matches)
+
+            # Calculate normalized distances based on metric
+            distances = [
+                [
+                    self._normalize_distance(match.score)
+                    for match in query_response.matches
+                ]
+            ]
+
+            return SearchResult(
+                ids=get_result.ids,
+                documents=get_result.documents,
+                metadatas=get_result.metadatas,
+                distances=distances,
+            )
+        except Exception as e:
+            log.error(f"Error searching in '{collection_name_with_prefix}': {e}")
+            return None
+
+    def query(
+        self, collection_name: str, filter: Dict, limit: Optional[int] = None
+    ) -> Optional[GetResult]:
+        """Query vectors by metadata filter."""
+        collection_name_with_prefix = self._get_collection_name_with_prefix(
+            collection_name
+        )
+
+        if limit is None or limit <= 0:
+            limit = NO_LIMIT
+
+        try:
+            # Create a zero vector for the dimension as Pinecone requires a vector
+            zero_vector = [0.0] * self.dimension
+
+            # Combine user filter with collection_name
+            pinecone_filter = {"collection_name": collection_name_with_prefix}
+            if filter:
+                pinecone_filter.update(filter)
+
+            # Perform metadata-only query
+            query_response = self.index.query(
+                vector=zero_vector,
+                filter=pinecone_filter,
+                top_k=limit,
+                include_metadata=True,
+            )
+
+            return self._result_to_get_result(query_response.matches)
+
+        except Exception as e:
+            log.error(f"Error querying collection '{collection_name}': {e}")
+            return None
+
+    def get(self, collection_name: str) -> Optional[GetResult]:
+        """Get all vectors in a collection."""
+        collection_name_with_prefix = self._get_collection_name_with_prefix(
+            collection_name
+        )
+
+        try:
+            # Use a zero vector for fetching all entries
+            zero_vector = [0.0] * self.dimension
+
+            # Add filter to only get vectors for this collection
+            query_response = self.index.query(
+                vector=zero_vector,
+                top_k=NO_LIMIT,
+                include_metadata=True,
+                filter={"collection_name": collection_name_with_prefix},
+            )
+
+            return self._result_to_get_result(query_response.matches)
+
+        except Exception as e:
+            log.error(f"Error getting collection '{collection_name}': {e}")
+            return None
+
+    def delete(
+        self,
+        collection_name: str,
+        ids: Optional[List[str]] = None,
+        filter: Optional[Dict] = None,
+    ) -> None:
+        """Delete vectors by IDs or filter."""
+        collection_name_with_prefix = self._get_collection_name_with_prefix(
+            collection_name
+        )
+
+        try:
+            if ids:
+                # Delete by IDs (in batches for large deletions)
+                for i in range(0, len(ids), BATCH_SIZE):
+                    batch_ids = ids[i : i + BATCH_SIZE]
+                    # Note: When deleting by ID, we can't filter by collection_name
+                    # This is a limitation of Pinecone - be careful with ID uniqueness
+                    self.index.delete(ids=batch_ids)
+                    log.debug(
+                        f"Deleted batch of {len(batch_ids)} vectors by ID from '{collection_name_with_prefix}'"
+                    )
+                log.info(
+                    f"Successfully deleted {len(ids)} vectors by ID from '{collection_name_with_prefix}'"
+                )
+
+            elif filter:
+                # Combine user filter with collection_name
+                pinecone_filter = {"collection_name": collection_name_with_prefix}
+                if filter:
+                    pinecone_filter.update(filter)
+                # Delete by metadata filter
+                self.index.delete(filter=pinecone_filter)
+                log.info(
+                    f"Successfully deleted vectors by filter from '{collection_name_with_prefix}'"
+                )
+
+            else:
+                log.warning("No ids or filter provided for delete operation")
+
+        except Exception as e:
+            log.error(f"Error deleting from collection '{collection_name}': {e}")
+            raise
+
+    def reset(self) -> None:
+        """Reset the database by deleting all collections."""
+        try:
+            self.index.delete(delete_all=True)
+            log.info("All vectors successfully deleted from the index.")
+        except Exception as e:
+            log.error(f"Failed to reset Pinecone index: {e}")
+            raise
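For reviewers, a short usage sketch of the new client. The collection name and data below are made up for illustration; items are passed as plain dicts, matching the dict-style access in _create_points, and the vector width must match PINECONE_DIMENSION:

    from open_webui.retrieval.vector.dbs.pinecone import PineconeClient

    client = PineconeClient()  # validates config and creates the serverless index if missing

    # Upsert into a logical collection (really just a collection_name tag in metadata)
    client.upsert(
        "file-abc123",
        [
            {"id": "doc-1", "vector": [0.1] * 1536, "text": "hello", "metadata": {"page": 1}},
            {"id": "doc-2", "vector": [0.2] * 1536, "text": "world", "metadata": {"page": 2}},
        ],
    )

    # Nearest-neighbour search; cosine scores come back normalized to [0, 1]
    results = client.search("file-abc123", vectors=[[0.1] * 1536], limit=5)

    # Remove the logical collection without touching others on the shared index
    client.delete_collection("file-abc123")

Note the design choice: all logical collections share one physical Pinecone index and are separated only by the collection_name metadata filter. That is also why delete-by-ID, which Pinecone cannot combine with a metadata filter, relies on IDs being unique across collections, as the inline comment in delete() warns.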

+ 1 - 1
backend/requirements.txt

@@ -50,7 +50,7 @@ qdrant-client~=1.12.0
 opensearch-py==2.8.0
 playwright==1.49.1 # Caution: version must match docker-compose.playwright.yaml
 elasticsearch==8.17.1
-
+pinecone==6.0.2
 
 transformers
 sentence-transformers==3.3.1

+ 1 - 0
pyproject.toml

@@ -58,6 +58,7 @@ dependencies = [
     "opensearch-py==2.8.0",
     "playwright==1.49.1",
     "elasticsearch==8.17.1",
+    "pinecone==6.0.2",
 
     "transformers",
     "sentence-transformers==3.3.1",