from typing import Optional, List, Dict, Any, Union
import logging
import time  # for measuring elapsed time

from pinecone import Pinecone, ServerlessSpec

# Add gRPC support for better performance (Pinecone best practice)
try:
    from pinecone.grpc import PineconeGRPC

    GRPC_AVAILABLE = True
except ImportError:
    GRPC_AVAILABLE = False

import asyncio  # for async upserts
import functools  # for partial binding in async tasks
import concurrent.futures  # for parallel batch upserts
import random  # for jitter in retry backoff

from open_webui.retrieval.vector.main import (
    VectorDBBase,
    VectorItem,
    SearchResult,
    GetResult,
)
from open_webui.config import (
    PINECONE_API_KEY,
    PINECONE_ENVIRONMENT,
    PINECONE_INDEX_NAME,
    PINECONE_DIMENSION,
    PINECONE_METRIC,
    PINECONE_CLOUD,
)
from open_webui.env import SRC_LOG_LEVELS
from open_webui.retrieval.vector.utils import stringify_metadata

NO_LIMIT = 10000  # Reasonable limit to avoid overwhelming the system
BATCH_SIZE = 100  # Recommended batch size for Pinecone operations

log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])


class PineconeClient(VectorDBBase):
    def __init__(self):
        self.collection_prefix = "open-webui"

        # Validate required configuration
        self._validate_config()

        # Store configuration values
        self.api_key = PINECONE_API_KEY
        self.environment = PINECONE_ENVIRONMENT
        self.index_name = PINECONE_INDEX_NAME
        self.dimension = PINECONE_DIMENSION
        self.metric = PINECONE_METRIC
        self.cloud = PINECONE_CLOUD

        # Initialize Pinecone client for improved performance
        if GRPC_AVAILABLE:
            # Use gRPC client for better performance (Pinecone recommendation)
            self.client = PineconeGRPC(
                api_key=self.api_key,
                pool_threads=20,  # Improved connection pool size
                timeout=30,  # Reasonable timeout for operations
            )
            self.using_grpc = True
            log.info("Using Pinecone gRPC client for optimal performance")
        else:
            # Fallback to HTTP client with enhanced connection pooling
            self.client = Pinecone(
                api_key=self.api_key,
                pool_threads=20,  # Improved connection pool size
                timeout=30,  # Reasonable timeout for operations
            )
            self.using_grpc = False
            log.info("Using Pinecone HTTP client (gRPC not available)")

        # Persistent executor for batch operations
        self._executor = concurrent.futures.ThreadPoolExecutor(max_workers=5)

        # Create index if it doesn't exist
        self._initialize_index()

    def _validate_config(self) -> None:
        """Validate that all required configuration variables are set."""
        missing_vars = []
        if not PINECONE_API_KEY:
            missing_vars.append("PINECONE_API_KEY")
        if not PINECONE_ENVIRONMENT:
            missing_vars.append("PINECONE_ENVIRONMENT")
        if not PINECONE_INDEX_NAME:
            missing_vars.append("PINECONE_INDEX_NAME")
        if not PINECONE_DIMENSION:
            missing_vars.append("PINECONE_DIMENSION")
        if not PINECONE_CLOUD:
            missing_vars.append("PINECONE_CLOUD")

        if missing_vars:
            raise ValueError(
                f"Required configuration missing: {', '.join(missing_vars)}"
            )

    def _initialize_index(self) -> None:
        """Initialize the Pinecone index."""
        try:
            # Check if index exists
            if self.index_name not in self.client.list_indexes().names():
                log.info(f"Creating Pinecone index '{self.index_name}'...")
                self.client.create_index(
                    name=self.index_name,
                    dimension=self.dimension,
                    metric=self.metric,
                    spec=ServerlessSpec(cloud=self.cloud, region=self.environment),
                )
                log.info(f"Successfully created Pinecone index '{self.index_name}'")
            else:
                log.info(f"Using existing Pinecone index '{self.index_name}'")

            # Connect to the index
            self.index = self.client.Index(
                self.index_name,
                pool_threads=20,  # Enhanced connection pool for index operations
            )
        except Exception as e:
            log.error(f"Failed to initialize Pinecone index: {e}")
            raise RuntimeError(f"Failed to initialize Pinecone index: {e}")

    def _retry_pinecone_operation(self, operation_func, max_retries=3):
        """Retry Pinecone operations with exponential backoff for rate limits and network issues."""
        for attempt in range(max_retries):
            try:
                return operation_func()
            except Exception as e:
                error_str = str(e).lower()
                # Check if it's a retryable error (rate limits, network issues, timeouts)
                is_retryable = any(
                    keyword in error_str
                    for keyword in [
                        "rate limit",
                        "quota",
                        "timeout",
                        "network",
                        "connection",
                        "unavailable",
                        "internal error",
                        "429",
                        "500",
                        "502",
                        "503",
                        "504",
                    ]
                )
                if not is_retryable or attempt == max_retries - 1:
                    # Don't retry for non-retryable errors or on final attempt
                    raise

                # Exponential backoff with jitter
                delay = (2**attempt) + random.uniform(0, 1)
                log.warning(
                    f"Pinecone operation failed (attempt {attempt + 1}/{max_retries}), "
                    f"retrying in {delay:.2f}s: {e}"
                )
                time.sleep(delay)
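
    # Illustrative use of the retry helper above. It is defined here but not wired
    # into the methods below by default; the lambda binding is an assumption about
    # how a caller would wrap an index operation:
    #
    #     response = self._retry_pinecone_operation(
    #         lambda: self.index.query(vector=query_vector, top_k=limit)
    #     )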

    def _create_points(
        self, items: List[VectorItem], collection_name_with_prefix: str
    ) -> List[Dict[str, Any]]:
        """Convert VectorItem objects to Pinecone point format."""
        points = []
        for item in items:
            # Start with any existing metadata or an empty dict
            metadata = item.get("metadata", {}).copy() if item.get("metadata") else {}

            # Add text to metadata if available
            if "text" in item:
                metadata["text"] = item["text"]

            # Always add collection_name to metadata for filtering
            metadata["collection_name"] = collection_name_with_prefix

            point = {
                "id": item["id"],
                "values": item["vector"],
                "metadata": stringify_metadata(metadata),
            }
            points.append(point)
        return points

    def _get_collection_name_with_prefix(self, collection_name: str) -> str:
        """Get the collection name with prefix."""
        return f"{self.collection_prefix}_{collection_name}"

    def _normalize_distance(self, score: float) -> float:
        """Normalize distance score based on the metric used."""
        if self.metric.lower() == "cosine":
            # Cosine similarity ranges from -1 to 1, normalize to 0 to 1
            return (score + 1.0) / 2.0
        elif self.metric.lower() in ["euclidean", "dotproduct"]:
            # These are already suitable for ranking (smaller is better for Euclidean)
            return score
        else:
            # For other metrics, use as is
            return score
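
    # For example, with the cosine metric a raw score of 0.5 normalizes to
    # (0.5 + 1.0) / 2.0 = 0.75, so callers always receive values in the [0, 1] range.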

    def _result_to_get_result(self, matches: list) -> GetResult:
        """Convert Pinecone matches to GetResult format."""
        ids = []
        documents = []
        metadatas = []

        for match in matches:
            metadata = getattr(match, "metadata", {}) or {}
            ids.append(match.id if hasattr(match, "id") else match["id"])
            documents.append(metadata.get("text", ""))
            metadatas.append(metadata)

        return GetResult(
            **{
                "ids": [ids],
                "documents": [documents],
                "metadatas": [metadatas],
            }
        )

    def has_collection(self, collection_name: str) -> bool:
        """Check if a collection exists by searching for at least one item."""
        collection_name_with_prefix = self._get_collection_name_with_prefix(
            collection_name
        )
        try:
            # Search for at least 1 item with this collection name in metadata
            response = self.index.query(
                vector=[0.0] * self.dimension,  # dummy vector
                top_k=1,
                filter={"collection_name": collection_name_with_prefix},
                include_metadata=False,
            )
            matches = getattr(response, "matches", []) or []
            return len(matches) > 0
        except Exception as e:
            log.exception(
                f"Error checking collection '{collection_name_with_prefix}': {e}"
            )
            return False

    def delete_collection(self, collection_name: str) -> None:
        """Delete a collection by removing all vectors with the collection name in metadata."""
        collection_name_with_prefix = self._get_collection_name_with_prefix(
            collection_name
        )
        try:
            self.index.delete(filter={"collection_name": collection_name_with_prefix})
            log.info(
                f"Collection '{collection_name_with_prefix}' deleted (all vectors removed)."
            )
        except Exception as e:
            log.warning(
                f"Failed to delete collection '{collection_name_with_prefix}': {e}"
            )
            raise

    def insert(self, collection_name: str, items: List[VectorItem]) -> None:
        """Insert vectors into a collection."""
        if not items:
            log.warning("No items to insert")
            return

        start_time = time.time()
        collection_name_with_prefix = self._get_collection_name_with_prefix(
            collection_name
        )
        points = self._create_points(items, collection_name_with_prefix)

        # Parallelize batch inserts for performance
        executor = self._executor
        futures = []
        for i in range(0, len(points), BATCH_SIZE):
            batch = points[i : i + BATCH_SIZE]
            futures.append(executor.submit(self.index.upsert, vectors=batch))
        for future in concurrent.futures.as_completed(futures):
            try:
                future.result()
            except Exception as e:
                log.error(f"Error inserting batch: {e}")
                raise

        elapsed = time.time() - start_time
        log.debug(f"Insert of {len(points)} vectors took {elapsed:.2f} seconds")
        log.info(
            f"Successfully inserted {len(points)} vectors in parallel batches "
            f"into '{collection_name_with_prefix}'"
        )

    def upsert(self, collection_name: str, items: List[VectorItem]) -> None:
        """Upsert (insert or update) vectors into a collection."""
        if not items:
            log.warning("No items to upsert")
            return

        start_time = time.time()
        collection_name_with_prefix = self._get_collection_name_with_prefix(
            collection_name
        )
        points = self._create_points(items, collection_name_with_prefix)

        # Parallelize batch upserts for performance
        executor = self._executor
        futures = []
        for i in range(0, len(points), BATCH_SIZE):
            batch = points[i : i + BATCH_SIZE]
            futures.append(executor.submit(self.index.upsert, vectors=batch))
        for future in concurrent.futures.as_completed(futures):
            try:
                future.result()
            except Exception as e:
                log.error(f"Error upserting batch: {e}")
                raise

        elapsed = time.time() - start_time
        log.debug(f"Upsert of {len(points)} vectors took {elapsed:.2f} seconds")
        log.info(
            f"Successfully upserted {len(points)} vectors in parallel batches "
            f"into '{collection_name_with_prefix}'"
        )
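
    # Note: insert() and upsert() share the same underlying write path because
    # Pinecone exposes only an upsert operation. With BATCH_SIZE = 100, e.g. 250
    # points are written as three parallel batches of 100, 100, and 50 vectors
    # (illustrative numbers, not taken from the source).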

    async def insert_async(self, collection_name: str, items: List[VectorItem]) -> None:
        """Async version of insert using asyncio and run_in_executor for improved performance."""
        if not items:
            log.warning("No items to insert")
            return

        collection_name_with_prefix = self._get_collection_name_with_prefix(
            collection_name
        )
        points = self._create_points(items, collection_name_with_prefix)

        # Create batches
        batches = [
            points[i : i + BATCH_SIZE] for i in range(0, len(points), BATCH_SIZE)
        ]

        # get_running_loop() is the supported call from inside a coroutine
        loop = asyncio.get_running_loop()
        tasks = [
            loop.run_in_executor(
                None, functools.partial(self.index.upsert, vectors=batch)
            )
            for batch in batches
        ]
        results = await asyncio.gather(*tasks, return_exceptions=True)
        for result in results:
            if isinstance(result, Exception):
                log.error(f"Error in async insert batch: {result}")
                raise result

        log.info(
            f"Successfully async inserted {len(points)} vectors in batches "
            f"into '{collection_name_with_prefix}'"
        )

    async def upsert_async(self, collection_name: str, items: List[VectorItem]) -> None:
        """Async version of upsert using asyncio and run_in_executor for improved performance."""
        if not items:
            log.warning("No items to upsert")
            return

        collection_name_with_prefix = self._get_collection_name_with_prefix(
            collection_name
        )
        points = self._create_points(items, collection_name_with_prefix)

        # Create batches
        batches = [
            points[i : i + BATCH_SIZE] for i in range(0, len(points), BATCH_SIZE)
        ]

        # get_running_loop() is the supported call from inside a coroutine
        loop = asyncio.get_running_loop()
        tasks = [
            loop.run_in_executor(
                None, functools.partial(self.index.upsert, vectors=batch)
            )
            for batch in batches
        ]
        results = await asyncio.gather(*tasks, return_exceptions=True)
        for result in results:
            if isinstance(result, Exception):
                log.error(f"Error in async upsert batch: {result}")
                raise result

        log.info(
            f"Successfully async upserted {len(points)} vectors in batches "
            f"into '{collection_name_with_prefix}'"
        )

    def search(
        self, collection_name: str, vectors: List[List[Union[float, int]]], limit: int
    ) -> Optional[SearchResult]:
        """Search for similar vectors in a collection."""
        if not vectors or not vectors[0]:
            log.warning("No vectors provided for search")
            return None

        collection_name_with_prefix = self._get_collection_name_with_prefix(
            collection_name
        )

        if limit is None or limit <= 0:
            limit = NO_LIMIT

        try:
            # Search using the first vector (assuming this is the intended behavior)
            query_vector = vectors[0]

            # Perform the search
            query_response = self.index.query(
                vector=query_vector,
                top_k=limit,
                include_metadata=True,
                filter={"collection_name": collection_name_with_prefix},
            )
            matches = getattr(query_response, "matches", []) or []

            if not matches:
                # Return empty result if no matches
                return SearchResult(
                    ids=[[]],
                    documents=[[]],
                    metadatas=[[]],
                    distances=[[]],
                )

            # Convert to GetResult format
            get_result = self._result_to_get_result(matches)

            # Calculate normalized distances based on metric
            distances = [
                [
                    self._normalize_distance(getattr(match, "score", 0.0))
                    for match in matches
                ]
            ]

            return SearchResult(
                ids=get_result.ids,
                documents=get_result.documents,
                metadatas=get_result.metadatas,
                distances=distances,
            )
        except Exception as e:
            log.error(f"Error searching in '{collection_name_with_prefix}': {e}")
            return None

    def query(
        self, collection_name: str, filter: Dict, limit: Optional[int] = None
    ) -> Optional[GetResult]:
        """Query vectors by metadata filter."""
        collection_name_with_prefix = self._get_collection_name_with_prefix(
            collection_name
        )

        if limit is None or limit <= 0:
            limit = NO_LIMIT

        try:
            # Create a zero vector for the dimension as Pinecone requires a vector
            zero_vector = [0.0] * self.dimension

            # Combine user filter with collection_name
            pinecone_filter = {"collection_name": collection_name_with_prefix}
            if filter:
                pinecone_filter.update(filter)

            # Perform metadata-only query
            query_response = self.index.query(
                vector=zero_vector,
                filter=pinecone_filter,
                top_k=limit,
                include_metadata=True,
            )
            matches = getattr(query_response, "matches", []) or []
            return self._result_to_get_result(matches)
        except Exception as e:
            log.error(f"Error querying collection '{collection_name}': {e}")
            return None

    def get(self, collection_name: str) -> Optional[GetResult]:
        """Get all vectors in a collection."""
        collection_name_with_prefix = self._get_collection_name_with_prefix(
            collection_name
        )
        try:
            # Use a zero vector for fetching all entries
            zero_vector = [0.0] * self.dimension

            # Add filter to only get vectors for this collection
            query_response = self.index.query(
                vector=zero_vector,
                top_k=NO_LIMIT,
                include_metadata=True,
                filter={"collection_name": collection_name_with_prefix},
            )
            matches = getattr(query_response, "matches", []) or []
            return self._result_to_get_result(matches)
        except Exception as e:
            log.error(f"Error getting collection '{collection_name}': {e}")
            return None

    def delete(
        self,
        collection_name: str,
        ids: Optional[List[str]] = None,
        filter: Optional[Dict] = None,
    ) -> None:
        """Delete vectors by IDs or filter."""
        collection_name_with_prefix = self._get_collection_name_with_prefix(
            collection_name
        )
        try:
            if ids:
                # Delete by IDs (in batches for large deletions)
                for i in range(0, len(ids), BATCH_SIZE):
                    batch_ids = ids[i : i + BATCH_SIZE]
                    # Note: When deleting by ID, we can't filter by collection_name
                    # This is a limitation of Pinecone - be careful with ID uniqueness
                    self.index.delete(ids=batch_ids)
                    log.debug(
                        f"Deleted batch of {len(batch_ids)} vectors by ID "
                        f"from '{collection_name_with_prefix}'"
                    )
                log.info(
                    f"Successfully deleted {len(ids)} vectors by ID "
                    f"from '{collection_name_with_prefix}'"
                )
            elif filter:
                # Combine user filter with collection_name
                pinecone_filter = {"collection_name": collection_name_with_prefix}
                if filter:
                    pinecone_filter.update(filter)

                # Delete by metadata filter
                self.index.delete(filter=pinecone_filter)
                log.info(
                    f"Successfully deleted vectors by filter from '{collection_name_with_prefix}'"
                )
            else:
                log.warning("No ids or filter provided for delete operation")
        except Exception as e:
            log.error(f"Error deleting from collection '{collection_name}': {e}")
            raise

    def reset(self) -> None:
        """Reset the database by deleting all collections."""
        try:
            self.index.delete(delete_all=True)
            log.info("All vectors successfully deleted from the index.")
        except Exception as e:
            log.error(f"Failed to reset Pinecone index: {e}")
            raise

    def close(self):
        """Shut down resources."""
        try:
            # The new Pinecone client doesn't need explicit closing
            pass
        except Exception as e:
            log.warning(f"Failed to clean up Pinecone resources: {e}")
        self._executor.shutdown(wait=True)

    def __enter__(self):
        """Enter context manager."""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Exit context manager, ensuring resources are cleaned up."""
        self.close()
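

# Example usage (an illustrative sketch, not part of the module API; assumes the
# required PINECONE_* settings are present in open_webui.config and that the item
# dicts follow the VectorItem shape consumed by _create_points):
#
#     with PineconeClient() as client:
#         client.upsert(
#             "docs",
#             [
#                 {
#                     "id": "doc-1",
#                     "vector": [0.1] * client.dimension,
#                     "text": "hello world",
#                     "metadata": {"source": "example.txt"},
#                 }
#             ],
#         )
#         results = client.search("docs", vectors=[[0.1] * client.dimension], limit=5)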