Browse Source

Merge pull request #13712 from PVBLIC-F/dev

perf - Pinecone.py
Tim Jaeryang Baek 9 tháng trước cách đây
mục cha
commit
54dda08d39
1 tập tin đã thay đổi với 127 bổ sung26 xóa
  1. 127 26
      backend/open_webui/retrieval/vector/dbs/pinecone.py

+ 127 - 26
backend/open_webui/retrieval/vector/dbs/pinecone.py

@@ -1,6 +1,13 @@
 from typing import Optional, List, Dict, Any, Union
 import logging
-from pinecone import Pinecone, ServerlessSpec
+import time  # for measuring elapsed time
+from pinecone import ServerlessSpec
+
+import asyncio  # for async upserts
+import functools  # for partial binding in async tasks
+
+import concurrent.futures  # for parallel batch upserts
+from pinecone.grpc import PineconeGRPC  # use gRPC client for faster upserts
 
 from open_webui.retrieval.vector.main import (
     VectorDBBase,
@@ -40,8 +47,11 @@ class PineconeClient(VectorDBBase):
         self.metric = PINECONE_METRIC
         self.cloud = PINECONE_CLOUD
 
-        # Initialize Pinecone client
-        self.client = Pinecone(api_key=self.api_key)
+        # Initialize Pinecone gRPC client for improved performance
+        self.client = PineconeGRPC(api_key=self.api_key, environment=self.environment, cloud=self.cloud)
+
+        # Persistent executor for batch operations
+        self._executor = concurrent.futures.ThreadPoolExecutor(max_workers=5)
 
         # Create index if it doesn't exist
         self._initialize_index()
@@ -191,28 +201,28 @@ class PineconeClient(VectorDBBase):
             log.warning("No items to insert")
             return
 
+        start_time = time.time()
+
         collection_name_with_prefix = self._get_collection_name_with_prefix(
             collection_name
         )
         points = self._create_points(items, collection_name_with_prefix)
 
-        # Insert in batches for better performance and reliability
+        # Parallelize batch inserts for performance
+        executor = self._executor
+        futures = []
         for i in range(0, len(points), BATCH_SIZE):
             batch = points[i : i + BATCH_SIZE]
+            futures.append(executor.submit(self.index.upsert, vectors=batch))
+        for future in concurrent.futures.as_completed(futures):
             try:
-                self.index.upsert(vectors=batch)
-                log.debug(
-                    f"Inserted batch of {len(batch)} vectors into '{collection_name_with_prefix}'"
-                )
+                future.result()
             except Exception as e:
-                log.error(
-                    f"Error inserting batch into '{collection_name_with_prefix}': {e}"
-                )
+                log.error(f"Error inserting batch: {e}")
                 raise
-
-        log.info(
-            f"Successfully inserted {len(items)} vectors into '{collection_name_with_prefix}'"
-        )
+        elapsed = time.time() - start_time
+        log.debug(f"Insert of {len(points)} vectors took {elapsed:.2f} seconds")
+        log.info(f"Successfully inserted {len(points)} vectors in parallel batches into '{collection_name_with_prefix}'")
 
     def upsert(self, collection_name: str, items: List[VectorItem]) -> None:
         """Upsert (insert or update) vectors into a collection."""
@@ -220,28 +230,102 @@ class PineconeClient(VectorDBBase):
             log.warning("No items to upsert")
             return
 
+        start_time = time.time()
+
         collection_name_with_prefix = self._get_collection_name_with_prefix(
             collection_name
         )
         points = self._create_points(items, collection_name_with_prefix)
 
-        # Upsert in batches
+        # Parallelize batch upserts for performance
+        executor = self._executor
+        futures = []
         for i in range(0, len(points), BATCH_SIZE):
             batch = points[i : i + BATCH_SIZE]
+            futures.append(executor.submit(self.index.upsert, vectors=batch))
+        for future in concurrent.futures.as_completed(futures):
             try:
-                self.index.upsert(vectors=batch)
-                log.debug(
-                    f"Upserted batch of {len(batch)} vectors into '{collection_name_with_prefix}'"
-                )
+                future.result()
             except Exception as e:
-                log.error(
-                    f"Error upserting batch into '{collection_name_with_prefix}': {e}"
-                )
+                log.error(f"Error upserting batch: {e}")
                 raise
+        elapsed = time.time() - start_time
+        log.debug(f"Upsert of {len(points)} vectors took {elapsed:.2f} seconds")
+        log.info(f"Successfully upserted {len(points)} vectors in parallel batches into '{collection_name_with_prefix}'")
 
-        log.info(
-            f"Successfully upserted {len(items)} vectors into '{collection_name_with_prefix}'"
-        )
+    async def insert_async(self, collection_name: str, items: List[VectorItem]) -> None:
+        """Async version of insert using asyncio and run_in_executor for improved performance."""
+        if not items:
+            log.warning("No items to insert")
+            return
+
+        collection_name_with_prefix = self._get_collection_name_with_prefix(collection_name)
+        points = self._create_points(items, collection_name_with_prefix)
+
+        # Create batches
+        batches = [points[i : i + BATCH_SIZE] for i in range(0, len(points), BATCH_SIZE)]
+        loop = asyncio.get_event_loop()
+        tasks = [
+            loop.run_in_executor(
+                None,
+                functools.partial(self.index.upsert, vectors=batch)
+            )
+            for batch in batches
+        ]
+        results = await asyncio.gather(*tasks, return_exceptions=True)
+        for result in results:
+            if isinstance(result, Exception):
+                log.error(f"Error in async insert batch: {result}")
+                raise result
+        log.info(f"Successfully async inserted {len(points)} vectors in batches into '{collection_name_with_prefix}'")
+
+    async def upsert_async(self, collection_name: str, items: List[VectorItem]) -> None:
+        """Async version of upsert using asyncio and run_in_executor for improved performance."""
+        if not items:
+            log.warning("No items to upsert")
+            return
+
+        collection_name_with_prefix = self._get_collection_name_with_prefix(collection_name)
+        points = self._create_points(items, collection_name_with_prefix)
+
+        # Create batches
+        batches = [points[i : i + BATCH_SIZE] for i in range(0, len(points), BATCH_SIZE)]
+        loop = asyncio.get_event_loop()
+        tasks = [
+            loop.run_in_executor(
+                None,
+                functools.partial(self.index.upsert, vectors=batch)
+            )
+            for batch in batches
+        ]
+        results = await asyncio.gather(*tasks, return_exceptions=True)
+        for result in results:
+            if isinstance(result, Exception):
+                log.error(f"Error in async upsert batch: {result}")
+                raise result
+        log.info(f"Successfully async upserted {len(points)} vectors in batches into '{collection_name_with_prefix}'")
+
+    def streaming_upsert(self, collection_name: str, items: List[VectorItem]) -> None:
+        """Perform a streaming upsert over gRPC for performance testing."""
+        if not items:
+            log.warning("No items to upsert via streaming")
+            return
+
+        collection_name_with_prefix = self._get_collection_name_with_prefix(collection_name)
+        points = self._create_points(items, collection_name_with_prefix)
+
+        # Open a streaming upsert channel
+        stream = self.index.streaming_upsert()
+        try:
+            for point in points:
+                # send each point over the stream
+                stream.send(point)
+            # close the stream to finalize
+            stream.close()
+            log.info(f"Successfully streamed upsert of {len(points)} vectors into '{collection_name_with_prefix}'")
+        except Exception as e:
+            log.error(f"Error during streaming upsert: {e}")
+            raise
 
     def search(
         self, collection_name: str, vectors: List[List[Union[float, int]]], limit: int
@@ -410,3 +494,20 @@ class PineconeClient(VectorDBBase):
         except Exception as e:
             log.error(f"Failed to reset Pinecone index: {e}")
             raise
+
+    def close(self):
+        """Shut down the gRPC channel and thread pool."""
+        try:
+            self.client.close()
+            log.info("Pinecone gRPC channel closed.")
+        except Exception as e:
+            log.warning(f"Failed to close Pinecone gRPC channel: {e}")
+        self._executor.shutdown(wait=True)
+
+    def __enter__(self):
+        """Enter context manager."""
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        """Exit context manager, ensuring resources are cleaned up."""
+        self.close()