浏览代码

Update pinecone.py

Refactor and added debug
PVBLIC Foundation 5 月之前
父节点
当前提交
12c2138982
共有 1 个文件被更改,包括 121 次插入37 次删除
  1. 121 37
      backend/open_webui/retrieval/vector/dbs/pinecone.py

+ 121 - 37
backend/open_webui/retrieval/vector/dbs/pinecone.py

@@ -1,8 +1,14 @@
 from typing import Optional, List, Dict, Any, Union
 import logging
-import asyncio
+import time  # for measuring elapsed time
 from pinecone import Pinecone, ServerlessSpec
 
+import asyncio  # for async upserts
+import functools  # for partial binding in async tasks
+
+import concurrent.futures  # for parallel batch upserts
+from pinecone.grpc import PineconeGRPC  # use gRPC client for faster upserts
+
 from open_webui.retrieval.vector.main import (
     VectorDBBase,
     VectorItem,
@@ -20,7 +26,7 @@ from open_webui.config import (
 from open_webui.env import SRC_LOG_LEVELS
 
 NO_LIMIT = 10000  # Reasonable limit to avoid overwhelming the system
-BATCH_SIZE = 200  # Recommended batch size for Pinecone operations
+BATCH_SIZE = 100  # Recommended batch size for Pinecone operations
 
 log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["RAG"])
@@ -41,8 +47,11 @@ class PineconeClient(VectorDBBase):
         self.metric = PINECONE_METRIC
         self.cloud = PINECONE_CLOUD
 
-        # Initialize Pinecone client
-        self.client = Pinecone(api_key=self.api_key)
+        # Initialize Pinecone gRPC client for improved performance
+        self.client = PineconeGRPC(api_key=self.api_key, environment=self.environment, cloud=self.cloud)
+
+        # Persistent executor for batch operations
+        self._executor = concurrent.futures.ThreadPoolExecutor(max_workers=5)
 
         # Create index if it doesn't exist
         self._initialize_index()
@@ -186,65 +195,137 @@ class PineconeClient(VectorDBBase):
             )
             raise
 
-    async def insert(self, collection_name: str, items: List[VectorItem]) -> None:
+    def insert(self, collection_name: str, items: List[VectorItem]) -> None:
         """Insert vectors into a collection."""
-        import time
         if not items:
             log.warning("No items to insert")
             return
 
+        start_time = time.time()
+
         collection_name_with_prefix = self._get_collection_name_with_prefix(
             collection_name
         )
         points = self._create_points(items, collection_name_with_prefix)
 
-        # Insert in batches for better performance and reliability
+        # Parallelize batch inserts for performance
+        executor = self._executor
+        futures = []
         for i in range(0, len(points), BATCH_SIZE):
             batch = points[i : i + BATCH_SIZE]
+            futures.append(executor.submit(self.index.upsert, vectors=batch))
+        for future in concurrent.futures.as_completed(futures):
             try:
-                start = time.time()
-                await asyncio.to_thread(self.index.upsert, vectors=batch)
-                elapsed = int((time.time() - start) * 1000)
-                # Log line removed as requested
+                future.result()
             except Exception as e:
-                log.error(
-                    f"Error inserting batch into '{collection_name_with_prefix}': {e}"
-                )
+                log.error(f"Error inserting batch: {e}")
                 raise
+        elapsed = time.time() - start_time
+        log.debug(f"Insert of {len(points)} vectors took {elapsed:.2f} seconds")
+        log.info(f"Successfully inserted {len(points)} vectors in parallel batches into '{collection_name_with_prefix}'")
 
-        log.info(
-            f"Successfully inserted {len(items)} vectors into '{collection_name_with_prefix}'"
-        )
-
-    async def upsert(self, collection_name: str, items: List[VectorItem]) -> None:
+    def upsert(self, collection_name: str, items: List[VectorItem]) -> None:
         """Upsert (insert or update) vectors into a collection."""
-        import time
         if not items:
             log.warning("No items to upsert")
             return
 
+        start_time = time.time()
+
         collection_name_with_prefix = self._get_collection_name_with_prefix(
             collection_name
         )
         points = self._create_points(items, collection_name_with_prefix)
 
-        # Upsert in batches
+        # Parallelize batch upserts for performance
+        executor = self._executor
+        futures = []
         for i in range(0, len(points), BATCH_SIZE):
             batch = points[i : i + BATCH_SIZE]
+            futures.append(executor.submit(self.index.upsert, vectors=batch))
+        for future in concurrent.futures.as_completed(futures):
             try:
-                start = time.time()
-                await asyncio.to_thread(self.index.upsert, vectors=batch)
-                elapsed = int((time.time() - start) * 1000)
-                # Log line removed as requested
+                future.result()
             except Exception as e:
-                log.error(
-                    f"Error upserting batch into '{collection_name_with_prefix}': {e}"
-                )
+                log.error(f"Error upserting batch: {e}")
                 raise
+        elapsed = time.time() - start_time
+        log.debug(f"Upsert of {len(points)} vectors took {elapsed:.2f} seconds")
+        log.info(f"Successfully upserted {len(points)} vectors in parallel batches into '{collection_name_with_prefix}'")
 
-        log.info(
-            f"Successfully upserted {len(items)} vectors into '{collection_name_with_prefix}'"
-        )
+    async def insert_async(self, collection_name: str, items: List[VectorItem]) -> None:
+        """Async version of insert using asyncio and run_in_executor for improved performance."""
+        if not items:
+            log.warning("No items to insert")
+            return
+
+        collection_name_with_prefix = self._get_collection_name_with_prefix(collection_name)
+        points = self._create_points(items, collection_name_with_prefix)
+
+        # Create batches
+        batches = [points[i : i + BATCH_SIZE] for i in range(0, len(points), BATCH_SIZE)]
+        loop = asyncio.get_event_loop()
+        tasks = [
+            loop.run_in_executor(
+                None,
+                functools.partial(self.index.upsert, vectors=batch)
+            )
+            for batch in batches
+        ]
+        results = await asyncio.gather(*tasks, return_exceptions=True)
+        for result in results:
+            if isinstance(result, Exception):
+                log.error(f"Error in async insert batch: {result}")
+                raise result
+        log.info(f"Successfully async inserted {len(points)} vectors in batches into '{collection_name_with_prefix}'")
+
+    async def upsert_async(self, collection_name: str, items: List[VectorItem]) -> None:
+        """Async version of upsert using asyncio and run_in_executor for improved performance."""
+        if not items:
+            log.warning("No items to upsert")
+            return
+
+        collection_name_with_prefix = self._get_collection_name_with_prefix(collection_name)
+        points = self._create_points(items, collection_name_with_prefix)
+
+        # Create batches
+        batches = [points[i : i + BATCH_SIZE] for i in range(0, len(points), BATCH_SIZE)]
+        loop = asyncio.get_event_loop()
+        tasks = [
+            loop.run_in_executor(
+                None,
+                functools.partial(self.index.upsert, vectors=batch)
+            )
+            for batch in batches
+        ]
+        results = await asyncio.gather(*tasks, return_exceptions=True)
+        for result in results:
+            if isinstance(result, Exception):
+                log.error(f"Error in async upsert batch: {result}")
+                raise result
+        log.info(f"Successfully async upserted {len(points)} vectors in batches into '{collection_name_with_prefix}'")
+
+    def streaming_upsert(self, collection_name: str, items: List[VectorItem]) -> None:
+        """Perform a streaming upsert over gRPC for performance testing."""
+        if not items:
+            log.warning("No items to upsert via streaming")
+            return
+
+        collection_name_with_prefix = self._get_collection_name_with_prefix(collection_name)
+        points = self._create_points(items, collection_name_with_prefix)
+
+        # Open a streaming upsert channel
+        stream = self.index.streaming_upsert()
+        try:
+            for point in points:
+                # send each point over the stream
+                stream.send(point)
+            # close the stream to finalize
+            stream.close()
+            log.info(f"Successfully streamed upsert of {len(points)} vectors into '{collection_name_with_prefix}'")
+        except Exception as e:
+            log.error(f"Error during streaming upsert: {e}")
+            raise
 
     def search(
         self, collection_name: str, vectors: List[List[Union[float, int]]], limit: int
@@ -361,14 +442,13 @@ class PineconeClient(VectorDBBase):
             log.error(f"Error getting collection '{collection_name}': {e}")
             return None
 
-    async def delete(
+    def delete(
         self,
         collection_name: str,
         ids: Optional[List[str]] = None,
         filter: Optional[Dict] = None,
     ) -> None:
         """Delete vectors by IDs or filter."""
-        import time
         collection_name_with_prefix = self._get_collection_name_with_prefix(
             collection_name
         )
@@ -380,10 +460,10 @@ class PineconeClient(VectorDBBase):
                     batch_ids = ids[i : i + BATCH_SIZE]
                     # Note: When deleting by ID, we can't filter by collection_name
                     # This is a limitation of Pinecone - be careful with ID uniqueness
-                    start = time.time()
-                    await asyncio.to_thread(self.index.delete, ids=batch_ids)
-                    elapsed = int((time.time() - start) * 1000)
-                    # Log line removed as requested
+                    self.index.delete(ids=batch_ids)
+                    log.debug(
+                        f"Deleted batch of {len(batch_ids)} vectors by ID from '{collection_name_with_prefix}'"
+                    )
                 log.info(
                     f"Successfully deleted {len(ids)} vectors by ID from '{collection_name_with_prefix}'"
                 )
@@ -414,3 +494,7 @@ class PineconeClient(VectorDBBase):
         except Exception as e:
             log.error(f"Failed to reset Pinecone index: {e}")
             raise
+
+    def close(self):
+        """Shut down the thread pool."""
+        self._executor.shutdown(wait=True)