ソースを参照

Merge pull request #14147 from PVBLIC-F/dev

perf Update pinecone.py
Tim Jaeryang Baek 4 ヶ月 前
コミット
0eda03bd3c
1 ファイル変更20 行追加45 行削除
  1. 20 45
      backend/open_webui/retrieval/vector/dbs/pinecone.py

+ 20 - 45
backend/open_webui/retrieval/vector/dbs/pinecone.py

@@ -1,13 +1,12 @@
 from typing import Optional, List, Dict, Any, Union
 import logging
 import time  # for measuring elapsed time
-from pinecone import ServerlessSpec
+from pinecone import Pinecone, ServerlessSpec
 
 import asyncio  # for async upserts
 import functools  # for partial binding in async tasks
 
 import concurrent.futures  # for parallel batch upserts
-from pinecone.grpc import PineconeGRPC  # use gRPC client for faster upserts
 
 from open_webui.retrieval.vector.main import (
     VectorDBBase,
@@ -47,10 +46,8 @@ class PineconeClient(VectorDBBase):
         self.metric = PINECONE_METRIC
         self.cloud = PINECONE_CLOUD
 
-        # Initialize Pinecone gRPC client for improved performance
-        self.client = PineconeGRPC(
-            api_key=self.api_key, environment=self.environment, cloud=self.cloud
-        )
+        # Initialize Pinecone client for improved performance
+        self.client = Pinecone(api_key=self.api_key)
 
         # Persistent executor for batch operations
         self._executor = concurrent.futures.ThreadPoolExecutor(max_workers=5)
@@ -147,8 +144,8 @@ class PineconeClient(VectorDBBase):
         metadatas = []
 
         for match in matches:
-            metadata = match.get("metadata", {})
-            ids.append(match["id"])
+            metadata = getattr(match, "metadata", {}) or {}
+            ids.append(match.id if hasattr(match, "id") else match["id"])
             documents.append(metadata.get("text", ""))
             metadatas.append(metadata)
 
@@ -174,7 +171,8 @@ class PineconeClient(VectorDBBase):
                 filter={"collection_name": collection_name_with_prefix},
                 include_metadata=False,
             )
-            return len(response.matches) > 0
+            matches = getattr(response, "matches", []) or []
+            return len(matches) > 0
         except Exception as e:
             log.exception(
                 f"Error checking collection '{collection_name_with_prefix}': {e}"
@@ -321,32 +319,6 @@ class PineconeClient(VectorDBBase):
             f"Successfully async upserted {len(points)} vectors in batches into '{collection_name_with_prefix}'"
         )
 
-    def streaming_upsert(self, collection_name: str, items: List[VectorItem]) -> None:
-        """Perform a streaming upsert over gRPC for performance testing."""
-        if not items:
-            log.warning("No items to upsert via streaming")
-            return
-
-        collection_name_with_prefix = self._get_collection_name_with_prefix(
-            collection_name
-        )
-        points = self._create_points(items, collection_name_with_prefix)
-
-        # Open a streaming upsert channel
-        stream = self.index.streaming_upsert()
-        try:
-            for point in points:
-                # send each point over the stream
-                stream.send(point)
-            # close the stream to finalize
-            stream.close()
-            log.info(
-                f"Successfully streamed upsert of {len(points)} vectors into '{collection_name_with_prefix}'"
-            )
-        except Exception as e:
-            log.error(f"Error during streaming upsert: {e}")
-            raise
-
     def search(
         self, collection_name: str, vectors: List[List[Union[float, int]]], limit: int
     ) -> Optional[SearchResult]:
@@ -374,7 +346,8 @@ class PineconeClient(VectorDBBase):
                 filter={"collection_name": collection_name_with_prefix},
             )
 
-            if not query_response.matches:
+            matches = getattr(query_response, "matches", []) or []
+            if not matches:
                 # Return empty result if no matches
                 return SearchResult(
                     ids=[[]],
@@ -384,13 +357,13 @@ class PineconeClient(VectorDBBase):
                 )
 
             # Convert to GetResult format
-            get_result = self._result_to_get_result(query_response.matches)
+            get_result = self._result_to_get_result(matches)
 
             # Calculate normalized distances based on metric
             distances = [
                 [
-                    self._normalize_distance(match.score)
-                    for match in query_response.matches
+                    self._normalize_distance(getattr(match, "score", 0.0))
+                    for match in matches
                 ]
             ]
 
@@ -432,7 +405,8 @@ class PineconeClient(VectorDBBase):
                 include_metadata=True,
             )
 
-            return self._result_to_get_result(query_response.matches)
+            matches = getattr(query_response, "matches", []) or []
+            return self._result_to_get_result(matches)
 
         except Exception as e:
             log.error(f"Error querying collection '{collection_name}': {e}")
@@ -456,7 +430,8 @@ class PineconeClient(VectorDBBase):
                 filter={"collection_name": collection_name_with_prefix},
             )
 
-            return self._result_to_get_result(query_response.matches)
+            matches = getattr(query_response, "matches", []) or []
+            return self._result_to_get_result(matches)
 
         except Exception as e:
             log.error(f"Error getting collection '{collection_name}': {e}")
@@ -516,12 +491,12 @@ class PineconeClient(VectorDBBase):
             raise
 
     def close(self):
-        """Shut down the gRPC channel and thread pool."""
+        """Shut down resources."""
         try:
-            self.client.close()
-            log.info("Pinecone gRPC channel closed.")
+            # The new Pinecone client doesn't need explicit closing
+            pass
         except Exception as e:
-            log.warning(f"Failed to close Pinecone gRPC channel: {e}")
+            log.warning(f"Failed to clean up Pinecone resources: {e}")
         self._executor.shutdown(wait=True)
 
     def __enter__(self):