from typing import Optional, List, Dict, Any, Union
import logging
import time  # for measuring elapsed time

from pinecone import Pinecone, ServerlessSpec

# Add gRPC support for better performance (Pinecone best practice)
try:
    from pinecone.grpc import PineconeGRPC

    GRPC_AVAILABLE = True
except ImportError:
    GRPC_AVAILABLE = False

import asyncio  # for async upserts
import functools  # for partial binding in async tasks
import concurrent.futures  # for parallel batch upserts
import random  # for jitter in retry backoff

from open_webui.retrieval.vector.main import (
    VectorDBBase,
    VectorItem,
    SearchResult,
    GetResult,
)
from open_webui.config import (
    PINECONE_API_KEY,
    PINECONE_ENVIRONMENT,
    PINECONE_INDEX_NAME,
    PINECONE_DIMENSION,
    PINECONE_METRIC,
    PINECONE_CLOUD,
)
from open_webui.env import SRC_LOG_LEVELS

NO_LIMIT = 10000  # Reasonable limit to avoid overwhelming the system
BATCH_SIZE = 100  # Recommended batch size for Pinecone operations

log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])


class PineconeClient(VectorDBBase):
    def __init__(self):
        self.collection_prefix = "open-webui"

        # Validate required configuration
        self._validate_config()

        # Store configuration values
        self.api_key = PINECONE_API_KEY
        self.environment = PINECONE_ENVIRONMENT
        self.index_name = PINECONE_INDEX_NAME
        self.dimension = PINECONE_DIMENSION
        self.metric = PINECONE_METRIC
        self.cloud = PINECONE_CLOUD

        # Initialize Pinecone client for improved performance
        if GRPC_AVAILABLE:
            # Use gRPC client for better performance (Pinecone recommendation)
            self.client = PineconeGRPC(
                api_key=self.api_key,
                pool_threads=20,  # Improved connection pool size
                timeout=30,  # Reasonable timeout for operations
            )
            self.using_grpc = True
            log.info("Using Pinecone gRPC client for optimal performance")
        else:
            # Fallback to HTTP client with enhanced connection pooling
            self.client = Pinecone(
                api_key=self.api_key,
                pool_threads=20,  # Improved connection pool size
                timeout=30,  # Reasonable timeout for operations
            )
            self.using_grpc = False
            log.info("Using Pinecone HTTP client (gRPC not available)")

        # Persistent executor for batch operations
        self._executor = concurrent.futures.ThreadPoolExecutor(max_workers=5)

        # Create index if it doesn't exist
        self._initialize_index()

    def _validate_config(self) -> None:
        """Validate that all required configuration variables are set."""
        missing_vars = []
        if not PINECONE_API_KEY:
            missing_vars.append("PINECONE_API_KEY")
        if not PINECONE_ENVIRONMENT:
            missing_vars.append("PINECONE_ENVIRONMENT")
        if not PINECONE_INDEX_NAME:
            missing_vars.append("PINECONE_INDEX_NAME")
        if not PINECONE_DIMENSION:
            missing_vars.append("PINECONE_DIMENSION")
        if not PINECONE_CLOUD:
            missing_vars.append("PINECONE_CLOUD")

        if missing_vars:
            raise ValueError(
                f"Required configuration missing: {', '.join(missing_vars)}"
            )

    def _initialize_index(self) -> None:
        """Initialize the Pinecone index."""
        try:
            # Check if index exists
            if self.index_name not in self.client.list_indexes().names():
                log.info(f"Creating Pinecone index '{self.index_name}'...")
                self.client.create_index(
                    name=self.index_name,
                    dimension=self.dimension,
                    metric=self.metric,
                    spec=ServerlessSpec(cloud=self.cloud, region=self.environment),
                )
                log.info(f"Successfully created Pinecone index '{self.index_name}'")
            else:
                log.info(f"Using existing Pinecone index '{self.index_name}'")

            # Connect to the index
            self.index = self.client.Index(
                self.index_name,
                pool_threads=20,  # Enhanced connection pool for index operations
            )
        except Exception as e:
            log.error(f"Failed to initialize Pinecone index: {e}")
            raise RuntimeError(f"Failed to initialize Pinecone index: {e}")

    def _retry_pinecone_operation(self, operation_func, max_retries=3):
        """Retry Pinecone operations with exponential backoff for rate limits and network issues."""
        for attempt in range(max_retries):
            try:
                return operation_func()
            except Exception as e:
                error_str = str(e).lower()
                # Check if it's a retryable error (rate limits, network issues, timeouts)
                is_retryable = any(keyword in error_str for keyword in [
                    'rate limit', 'quota', 'timeout', 'network', 'connection',
                    'unavailable', 'internal error', '429', '500', '502', '503', '504'
                ])
                if not is_retryable or attempt == max_retries - 1:
                    # Don't retry for non-retryable errors or on final attempt
                    raise
                # Exponential backoff with jitter
                delay = (2 ** attempt) + random.uniform(0, 1)
                log.warning(
                    f"Pinecone operation failed (attempt {attempt + 1}/{max_retries}), retrying in {delay:.2f}s: {e}"
                )
                time.sleep(delay)
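
    # Note: the retry helper above is not wired into the upsert/query call sites in
    # this module. An illustrative (hypothetical) use would be to wrap a write call,
    # e.g. self._retry_pinecone_operation(lambda: self.index.upsert(vectors=batch)),
    # so transient rate-limit or network errors are retried with backoff.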

    def _create_points(
        self, items: List[VectorItem], collection_name_with_prefix: str
    ) -> List[Dict[str, Any]]:
        """Convert VectorItem objects to Pinecone point format."""
        points = []
        for item in items:
            # Start with any existing metadata or an empty dict
            metadata = item.get("metadata", {}).copy() if item.get("metadata") else {}

            # Add text to metadata if available
            if "text" in item:
                metadata["text"] = item["text"]

            # Always add collection_name to metadata for filtering
            metadata["collection_name"] = collection_name_with_prefix

            point = {
                "id": item["id"],
                "values": item["vector"],
                "metadata": metadata,
            }
            points.append(point)
        return points
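
    # Illustrative item shape for _create_points, inferred from the key accesses
    # above (the values are made-up examples, not data from this module):
    #     {"id": "doc-1", "vector": [0.1, 0.2, ...], "text": "chunk text",
    #      "metadata": {"source": "notes.txt"}}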

    def _get_collection_name_with_prefix(self, collection_name: str) -> str:
        """Get the collection name with prefix."""
        return f"{self.collection_prefix}_{collection_name}"

    def _normalize_distance(self, score: float) -> float:
        """Normalize distance score based on the metric used."""
        if self.metric.lower() == "cosine":
            # Cosine similarity ranges from -1 to 1, normalize to 0 to 1
            return (score + 1.0) / 2.0
        elif self.metric.lower() in ["euclidean", "dotproduct"]:
            # These are already suitable for ranking (smaller is better for Euclidean)
            return score
        else:
            # For other metrics, use as is
            return score
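
    # Worked example for the cosine branch above: a raw score of 0.8 becomes
    # (0.8 + 1.0) / 2.0 = 0.9, and -1.0 becomes 0.0, so normalized cosine scores
    # always fall in [0, 1].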

    def _result_to_get_result(self, matches: list) -> GetResult:
        """Convert Pinecone matches to GetResult format."""
        ids = []
        documents = []
        metadatas = []

        for match in matches:
            metadata = getattr(match, "metadata", {}) or {}
            ids.append(match.id if hasattr(match, "id") else match["id"])
            documents.append(metadata.get("text", ""))
            metadatas.append(metadata)

        return GetResult(
            **{
                "ids": [ids],
                "documents": [documents],
                "metadatas": [metadatas],
            }
        )

    def has_collection(self, collection_name: str) -> bool:
        """Check if a collection exists by searching for at least one item."""
        collection_name_with_prefix = self._get_collection_name_with_prefix(
            collection_name
        )
        try:
            # Search for at least 1 item with this collection name in metadata
            response = self.index.query(
                vector=[0.0] * self.dimension,  # dummy vector
                top_k=1,
                filter={"collection_name": collection_name_with_prefix},
                include_metadata=False,
            )
            matches = getattr(response, "matches", []) or []
            return len(matches) > 0
        except Exception as e:
            log.exception(
                f"Error checking collection '{collection_name_with_prefix}': {e}"
            )
            return False

    def delete_collection(self, collection_name: str) -> None:
        """Delete a collection by removing all vectors with the collection name in metadata."""
        collection_name_with_prefix = self._get_collection_name_with_prefix(
            collection_name
        )
        try:
            self.index.delete(filter={"collection_name": collection_name_with_prefix})
            log.info(
                f"Collection '{collection_name_with_prefix}' deleted (all vectors removed)."
            )
        except Exception as e:
            log.warning(
                f"Failed to delete collection '{collection_name_with_prefix}': {e}"
            )
            raise

    def insert(self, collection_name: str, items: List[VectorItem]) -> None:
        """Insert vectors into a collection."""
        if not items:
            log.warning("No items to insert")
            return

        start_time = time.time()
        collection_name_with_prefix = self._get_collection_name_with_prefix(
            collection_name
        )
        points = self._create_points(items, collection_name_with_prefix)

        # Parallelize batch inserts for performance
        executor = self._executor
        futures = []
        for i in range(0, len(points), BATCH_SIZE):
            batch = points[i : i + BATCH_SIZE]
            futures.append(executor.submit(self.index.upsert, vectors=batch))
        for future in concurrent.futures.as_completed(futures):
            try:
                future.result()
            except Exception as e:
                log.error(f"Error inserting batch: {e}")
                raise

        elapsed = time.time() - start_time
        log.debug(f"Insert of {len(points)} vectors took {elapsed:.2f} seconds")
        log.info(
            f"Successfully inserted {len(points)} vectors in parallel batches into '{collection_name_with_prefix}'"
        )

    def upsert(self, collection_name: str, items: List[VectorItem]) -> None:
        """Upsert (insert or update) vectors into a collection."""
        if not items:
            log.warning("No items to upsert")
            return

        start_time = time.time()
        collection_name_with_prefix = self._get_collection_name_with_prefix(
            collection_name
        )
        points = self._create_points(items, collection_name_with_prefix)

        # Parallelize batch upserts for performance
        executor = self._executor
        futures = []
        for i in range(0, len(points), BATCH_SIZE):
            batch = points[i : i + BATCH_SIZE]
            futures.append(executor.submit(self.index.upsert, vectors=batch))
        for future in concurrent.futures.as_completed(futures):
            try:
                future.result()
            except Exception as e:
                log.error(f"Error upserting batch: {e}")
                raise

        elapsed = time.time() - start_time
        log.debug(f"Upsert of {len(points)} vectors took {elapsed:.2f} seconds")
        log.info(
            f"Successfully upserted {len(points)} vectors in parallel batches into '{collection_name_with_prefix}'"
        )

    async def insert_async(self, collection_name: str, items: List[VectorItem]) -> None:
        """Async version of insert using asyncio and run_in_executor for improved performance."""
        if not items:
            log.warning("No items to insert")
            return

        collection_name_with_prefix = self._get_collection_name_with_prefix(
            collection_name
        )
        points = self._create_points(items, collection_name_with_prefix)

        # Create batches
        batches = [
            points[i : i + BATCH_SIZE] for i in range(0, len(points), BATCH_SIZE)
        ]

        loop = asyncio.get_event_loop()
        tasks = [
            loop.run_in_executor(
                None, functools.partial(self.index.upsert, vectors=batch)
            )
            for batch in batches
        ]
        results = await asyncio.gather(*tasks, return_exceptions=True)
        for result in results:
            if isinstance(result, Exception):
                log.error(f"Error in async insert batch: {result}")
                raise result

        log.info(
            f"Successfully async inserted {len(points)} vectors in batches into '{collection_name_with_prefix}'"
        )

    async def upsert_async(self, collection_name: str, items: List[VectorItem]) -> None:
        """Async version of upsert using asyncio and run_in_executor for improved performance."""
        if not items:
            log.warning("No items to upsert")
            return

        collection_name_with_prefix = self._get_collection_name_with_prefix(
            collection_name
        )
        points = self._create_points(items, collection_name_with_prefix)

        # Create batches
        batches = [
            points[i : i + BATCH_SIZE] for i in range(0, len(points), BATCH_SIZE)
        ]

        loop = asyncio.get_event_loop()
        tasks = [
            loop.run_in_executor(
                None, functools.partial(self.index.upsert, vectors=batch)
            )
            for batch in batches
        ]
        results = await asyncio.gather(*tasks, return_exceptions=True)
        for result in results:
            if isinstance(result, Exception):
                log.error(f"Error in async upsert batch: {result}")
                raise result

        log.info(
            f"Successfully async upserted {len(points)} vectors in batches into '{collection_name_with_prefix}'"
        )
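
    # Illustrative only (this module does not invoke the coroutines above itself):
    # from synchronous code they can be driven with asyncio.run, e.g.
    #     asyncio.run(client.upsert_async("my-collection", items))
    # and from code already running inside an event loop they are simply awaited.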

    def search(
        self, collection_name: str, vectors: List[List[Union[float, int]]], limit: int
    ) -> Optional[SearchResult]:
        """Search for similar vectors in a collection."""
        if not vectors or not vectors[0]:
            log.warning("No vectors provided for search")
            return None

        collection_name_with_prefix = self._get_collection_name_with_prefix(
            collection_name
        )

        if limit is None or limit <= 0:
            limit = NO_LIMIT

        try:
            # Search using the first vector (assuming this is the intended behavior)
            query_vector = vectors[0]

            # Perform the search
            query_response = self.index.query(
                vector=query_vector,
                top_k=limit,
                include_metadata=True,
                filter={"collection_name": collection_name_with_prefix},
            )
            matches = getattr(query_response, "matches", []) or []

            if not matches:
                # Return empty result if no matches
                return SearchResult(
                    ids=[[]],
                    documents=[[]],
                    metadatas=[[]],
                    distances=[[]],
                )

            # Convert to GetResult format
            get_result = self._result_to_get_result(matches)

            # Calculate normalized distances based on metric
            distances = [
                [
                    self._normalize_distance(getattr(match, "score", 0.0))
                    for match in matches
                ]
            ]

            return SearchResult(
                ids=get_result.ids,
                documents=get_result.documents,
                metadatas=get_result.metadatas,
                distances=distances,
            )
        except Exception as e:
            log.error(f"Error searching in '{collection_name_with_prefix}': {e}")
            return None

    def query(
        self, collection_name: str, filter: Dict, limit: Optional[int] = None
    ) -> Optional[GetResult]:
        """Query vectors by metadata filter."""
        collection_name_with_prefix = self._get_collection_name_with_prefix(
            collection_name
        )

        if limit is None or limit <= 0:
            limit = NO_LIMIT

        try:
            # Create a zero vector for the dimension as Pinecone requires a vector
            zero_vector = [0.0] * self.dimension

            # Combine user filter with collection_name
            pinecone_filter = {"collection_name": collection_name_with_prefix}
            if filter:
                pinecone_filter.update(filter)

            # Perform metadata-only query
            query_response = self.index.query(
                vector=zero_vector,
                filter=pinecone_filter,
                top_k=limit,
                include_metadata=True,
            )
            matches = getattr(query_response, "matches", []) or []
            return self._result_to_get_result(matches)
        except Exception as e:
            log.error(f"Error querying collection '{collection_name}': {e}")
            return None

    def get(self, collection_name: str) -> Optional[GetResult]:
        """Get all vectors in a collection."""
        collection_name_with_prefix = self._get_collection_name_with_prefix(
            collection_name
        )
        try:
            # Use a zero vector for fetching all entries
            zero_vector = [0.0] * self.dimension

            # Add filter to only get vectors for this collection
            query_response = self.index.query(
                vector=zero_vector,
                top_k=NO_LIMIT,
                include_metadata=True,
                filter={"collection_name": collection_name_with_prefix},
            )
            matches = getattr(query_response, "matches", []) or []
            return self._result_to_get_result(matches)
        except Exception as e:
            log.error(f"Error getting collection '{collection_name}': {e}")
            return None

    def delete(
        self,
        collection_name: str,
        ids: Optional[List[str]] = None,
        filter: Optional[Dict] = None,
    ) -> None:
        """Delete vectors by IDs or filter."""
        collection_name_with_prefix = self._get_collection_name_with_prefix(
            collection_name
        )
        try:
            if ids:
                # Delete by IDs (in batches for large deletions)
                for i in range(0, len(ids), BATCH_SIZE):
                    batch_ids = ids[i : i + BATCH_SIZE]
                    # Note: When deleting by ID, we can't filter by collection_name
                    # This is a limitation of Pinecone - be careful with ID uniqueness
                    self.index.delete(ids=batch_ids)
                    log.debug(
                        f"Deleted batch of {len(batch_ids)} vectors by ID from '{collection_name_with_prefix}'"
                    )
                log.info(
                    f"Successfully deleted {len(ids)} vectors by ID from '{collection_name_with_prefix}'"
                )
            elif filter:
                # Combine user filter with collection_name
                pinecone_filter = {"collection_name": collection_name_with_prefix}
                if filter:
                    pinecone_filter.update(filter)
                # Delete by metadata filter
                self.index.delete(filter=pinecone_filter)
                log.info(
                    f"Successfully deleted vectors by filter from '{collection_name_with_prefix}'"
                )
            else:
                log.warning("No ids or filter provided for delete operation")
        except Exception as e:
            log.error(f"Error deleting from collection '{collection_name}': {e}")
            raise

    def reset(self) -> None:
        """Reset the database by deleting all collections."""
        try:
            self.index.delete(delete_all=True)
            log.info("All vectors successfully deleted from the index.")
        except Exception as e:
            log.error(f"Failed to reset Pinecone index: {e}")
            raise

    def close(self):
        """Shut down resources."""
        try:
            # The new Pinecone client doesn't need explicit closing
            pass
        except Exception as e:
            log.warning(f"Failed to clean up Pinecone resources: {e}")
        self._executor.shutdown(wait=True)

    def __enter__(self):
        """Enter context manager."""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Exit context manager, ensuring resources are cleaned up."""
        self.close()
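

# Illustrative usage sketch, not part of the original module: it assumes the
# PINECONE_* values in open_webui.config are populated and the index is reachable;
# the collection name and item values below are made-up examples.
if __name__ == "__main__":
    with PineconeClient() as client:
        # Item keys mirror what _create_points reads: id, vector, text, metadata.
        example_items = [
            {
                "id": "example-doc-1",
                "vector": [0.0] * client.dimension,
                "text": "example chunk",
                "metadata": {"source": "example.txt"},
            }
        ]
        client.upsert("example-collection", example_items)
        result = client.search(
            "example-collection", [[0.0] * client.dimension], limit=3
        )
        log.info(f"Example search returned: {result}")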